(Symbiosis of AI and Traditional Numerical Methods)
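This chapter focuses on embedding AI models as “intelligent components” in mature HPC workflows rather than replacing traditional solvers wholesale.
It follows three technical routes (differentiable physics simulators, neural closure models, and AI-accelerated iterative solvers), explains how each complements classical numerical methods at the mathematical level, and gives the key theoretical analysis and numerical stability guarantees.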
1. Differentiable Physics Simulators
1.1 Mathematical derivation
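def step(u, p, dt, nu, f):
    # Pseudocode: convective_term, viscosity_term, grad, A, b and
    # solve_linear_system stand for whatever discretization is in use.
    # Nonlinear term; assumed to return -(u·∇)u with the sign included
    conv = convective_term(u)
    # Viscous term: ν Δu
    visc = viscosity_term(u, nu)
    rhs = -grad(p) + conv + visc + f
    # Implicit time step: solve the linear system A(dt) u_new = b(rhs)
    u_new = solve_linear_system(A(dt), b(rhs))
    return u_new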
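To make the "implicit step, then solve a linear system" pattern concrete, here is a minimal JAX sketch for the simplest case: 1-D periodic diffusion with backward Euler, with no pressure or convection. The names `laplacian_matrix`, `implicit_step` and `final_energy` are illustrative assumptions, not part of the pseudocode above, and the dense matrix is used only for brevity. The point is that `jax.value_and_grad` differentiates straight through `jnp.linalg.solve`.

```python
import jax
import jax.numpy as jnp


def laplacian_matrix(n: int, dx: float) -> jnp.ndarray:
    """Dense periodic second-difference matrix (illustrative only)."""
    eye = jnp.eye(n)
    return (jnp.roll(eye, 1, axis=1) - 2.0 * eye + jnp.roll(eye, -1, axis=1)) / dx**2


def implicit_step(u, nu, dt, dx, f):
    """Backward Euler: solve (I - dt*nu*L) u_new = u + dt*f."""
    n = u.shape[0]
    A = jnp.eye(n) - dt * nu * laplacian_matrix(n, dx)
    return jnp.linalg.solve(A, u + dt * f)


def final_energy(nu, u0, dt=1e-2, dx=1.0 / 32, steps=10):
    """Energy 0.5 * integral(u^2) after `steps` implicit steps."""
    u = u0
    f = jnp.zeros_like(u0)
    for _ in range(steps):
        u = implicit_step(u, nu, dt, dx, f)
    return 0.5 * jnp.sum(u**2) * dx


x = jnp.arange(32) / 32
u0 = jnp.sin(2 * jnp.pi * x)

# Gradient of the final energy with respect to the viscosity
energy, d_energy_d_nu = jax.value_and_grad(final_energy)(0.05, u0)
print("energy:", energy, "dE/dnu:", d_energy_d_nu)
```

The derivative comes out negative, as expected: a larger viscosity dissipates more energy over the same time window.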
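1.2 Gradient-based optimization
Theoretical guarantee: if each operator $R_k$ is continuously differentiable and Lipschitz, then the overall map $\Phi$ is continuously differentiable as well, by the chain rule. When the resulting optimization problem is convex with a Lipschitz-continuous gradient, gradient descent with a sufficiently small step size converges to a global minimizer. When the problem is only locally differentiable, standard methods such as gradient descent or Adam can in general be expected to reach a stationary point at best.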
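As an illustration of gradient-based optimization through a simulator, the sketch below recovers a diffusivity from a synthetic observation by plain gradient descent. Everything in it is an assumption chosen for the example: the explicit 1-D diffusion rollout, the "true" value 0.02, the starting guess and the hand-tuned step size 0.05.

```python
import jax
import jax.numpy as jnp

NX, DX, DT, STEPS = 64, 1.0 / 64, 1e-3, 100


def rollout(nu, u0):
    """Explicit Euler rollout of periodic 1-D diffusion, written with lax.scan."""
    def step(u, _):
        lap = (jnp.roll(u, -1) - 2.0 * u + jnp.roll(u, 1)) / DX**2
        return u + DT * nu * lap, None

    u_final, _ = jax.lax.scan(step, u0, None, length=STEPS)
    return u_final


x = jnp.arange(NX) * DX
u0 = jnp.sin(2 * jnp.pi * x)
target = rollout(0.02, u0)  # synthetic observation generated with nu = 0.02


def loss(nu):
    return jnp.mean((rollout(nu, u0) - target) ** 2)


grad_loss = jax.jit(jax.grad(loss))

nu = 0.005  # deliberately wrong starting guess
for _ in range(200):
    nu = nu - 0.05 * float(grad_loss(nu))  # step size tuned by hand for this problem

print("recovered nu:", nu)
```

This one-parameter problem is smooth and well conditioned near the optimum, so plain gradient descent is enough. Real inverse problems usually need an adaptive optimizer and some care with step sizes.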
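2. Neural Closure Models
In LES and RANS simulations, the unresolved terms (subgrid-scale stresses in LES, Reynolds stresses in RANS) must be closed. The traditional approach uses empirical models such as Smagorinsky or Reynolds stress models. A neural network can instead learn a closure that aims to be more accurate and more general.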
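For reference, this is the kind of empirical baseline a learned closure is meant to improve on. The sketch is a one-dimensional analogue of the Smagorinsky eddy-viscosity model, $\nu_t = (C_s \Delta)^2 |\partial u / \partial x|$. The reduction to 1-D, the function name and the constant `c_s = 0.17` (a commonly quoted value) are all assumptions made for illustration.

```python
import numpy as np


def smagorinsky_closure_1d(u, dx, c_s=0.17):
    """1-D analogue of Smagorinsky: eddy viscosity and the resulting closure term."""
    # Resolved strain rate (central difference, periodic)
    dudx = (np.roll(u, -1) - np.roll(u, 1)) / (2.0 * dx)
    # Eddy viscosity scales with filter width squared times strain magnitude
    nu_t = (c_s * dx) ** 2 * np.abs(dudx)
    # Closure term d/dx(nu_t * du/dx), written in flux form
    flux = nu_t * dudx
    closure = (np.roll(flux, -1) - np.roll(flux, 1)) / (2.0 * dx)
    return nu_t, closure


x = np.arange(64) / 64
u = np.sin(2 * np.pi * x)
nu_t, closure = smagorinsky_closure_1d(u, dx=1.0 / 64)
print("max eddy viscosity:", nu_t.max())
```

Two properties are worth keeping in a learned replacement: the eddy viscosity is never negative, and the flux form means the closure adds no net momentum on a periodic domain.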
2.1 Mathematical framework for closure
2.2 Network training objective
2.3 Theoretical analysis
3. AI-Accelerated Iterative Solvers
3.1 AI-predicted coarsening operators
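As a minimal sketch of where a predicted coarsening operator would plug in, the two-grid cycle below for a 1-D Poisson problem takes the restriction matrix `R` as an argument. Here `R` is classical full weighting, an assumption standing in for a network output. The weighted-Jacobi smoother, the Galerkin coarse operator `R @ A @ P` and all function names are likewise illustrative choices.

```python
import numpy as np


def poisson_matrix(n):
    """1-D Dirichlet Poisson matrix on n interior points."""
    h = 1.0 / (n + 1)
    return (2.0 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)) / h**2


def linear_prolongation(m):
    """Linear interpolation from m coarse points to 2m+1 fine points."""
    P = np.zeros((2 * m + 1, m))
    for j in range(m):
        P[2 * j, j] = 0.5
        P[2 * j + 1, j] = 1.0
        P[2 * j + 2, j] = 0.5
    return P


def two_grid_cycle(A, b, x, R, P, omega=2.0 / 3.0):
    """Pre-smooth, coarse-grid correction with restriction R, post-smooth."""
    d = np.diag(A)
    x = x + omega * (b - A @ x) / d          # weighted Jacobi pre-smoothing
    A_c = R @ A @ P                          # Galerkin coarse operator
    x = x + P @ np.linalg.solve(A_c, R @ (b - A @ x))
    x = x + omega * (b - A @ x) / d          # weighted Jacobi post-smoothing
    return x


m = 31
P = linear_prolongation(m)
R = 0.5 * P.T  # full weighting; a learned operator would be substituted here
A = poisson_matrix(2 * m + 1)
b = np.ones(2 * m + 1)
x = np.zeros_like(b)

residuals = [np.linalg.norm(b - A @ x)]
for _ in range(5):
    x = two_grid_cycle(A, b, x, R, P)
    residuals.append(np.linalg.norm(b - A @ x))
print(residuals)
```

Tracking the residual history this way gives a direct, solver-level metric for judging any operator a network proposes.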
3.2 AI-predicted initial guess
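The effect of a good initial guess can be seen on a small example. In this sketch a Jacobi solver is started once from zeros and once from a "surrogate" guess. The surrogate is simply the exact solution plus small noise, an assumption standing in for a network prediction. The test matrix and tolerance are also illustrative.

```python
import numpy as np


def jacobi(A, b, x0, tol=1e-10, max_iter=10_000):
    """Jacobi iteration; returns the solution and the number of iterations used."""
    d = np.diag(A)
    x = x0.copy()
    b_norm = np.linalg.norm(b)
    for k in range(max_iter):
        r = b - A @ x
        if np.linalg.norm(r) <= tol * b_norm:
            return x, k
        x = x + r / d
    return x, max_iter


rng = np.random.default_rng(0)
n = 50
# Diagonally dominant tridiagonal system, so Jacobi converges
A = 4.0 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)
b = rng.standard_normal(n)
x_true = np.linalg.solve(A, b)

# Stand-in for a network prediction: exact solution plus small noise
surrogate_guess = x_true + 1e-3 * rng.standard_normal(n)

_, cold_iters = jacobi(A, b, np.zeros(n))
_, warm_iters = jacobi(A, b, surrogate_guess)
print("cold start:", cold_iters, "warm start:", warm_iters)
```

The convergence rate is unchanged because the iteration matrix is the same. The warm start only shortens the distance to cover, which is why this use of AI cannot break convergence.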
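3.3 Theoretical guarantees
- If the iteration matrix built from the AI output has spectral radius < 1, the corresponding iteration remains convergent.
- A regularization constraint (for example $\|R\|_F^2 \le C$) or a projection layer can be used to keep the output within the admissible range.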
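The two bullets above can be checked numerically. The sketch assumes the constrained quantity is the iteration matrix itself and uses the fact that the spectral radius never exceeds the Frobenius norm, so projecting onto $\|M\|_F^2 \le C$ with $C < 1$ is a sufficient (and conservative) way to enforce convergence. The random matrix is a stand-in for an unconstrained network output, and both function names are hypothetical.

```python
import numpy as np


def spectral_radius(M):
    """Largest eigenvalue magnitude of M."""
    return np.max(np.abs(np.linalg.eigvals(M)))


def project_frobenius(M, c):
    """Scale M onto the ball ||M||_F^2 <= c (no-op if already inside)."""
    sq = np.sum(M**2)
    return M if sq <= c else M * np.sqrt(c / sq)


rng = np.random.default_rng(0)
M_raw = rng.standard_normal((8, 8))      # stand-in for a raw network output
M_safe = project_frobenius(M_raw, c=0.81)

print("raw  spectral radius:", spectral_radius(M_raw))
print("safe spectral radius:", spectral_radius(M_safe))
```

The bound is loose. A matrix can have a large Frobenius norm and still a small spectral radius, so a tighter constraint such as a spectral-norm penalty would give up less expressiveness.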
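4. Challenges and Open Problems
Problem | Current status | Potential directions |
---|---|---|
Numerical stability | AI models may produce unphysical tensors, pushing the spectral radius above 1 | Add physics-constraint layers; use differentiable regularization; design robust training datasets |
Convergence guarantees | Rigorous theory for hybrid models is lacking | Derive convergence conditions based on Lyapunov functions; study how gradient clipping affects iteration stability |
High-dimensional generalization | Training samples are limited and cannot cover all flow regimes | Use self-supervised learning, meta-learning, or reinforcement learning to generate diverse data |
Interpretability | Network outputs lack physical intuition | Develop visualization tools; map the network structure onto traditional models (e.g. Smagorinsky) |
Computational cost | The AI model itself needs GPUs/TPUs for training and inference | Use lightweight networks, sparse parameterization, knowledge distillation |
Multiscale consistency | Physical mismatch between fine and coarse grids | Use adaptive mesh refinement combined with an AI-driven adjustment strategy |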
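5. Summary
- Differentiable physics simulators: provide end-to-end gradients that can be used for parameter identification, control, and inverse problems.
- Neural closure models: predict the subgrid-scale tensor in LES / RANS simulations with the aim of being more accurate and more general than empirical models. By universal approximation they can in principle approximate any continuous closure function on a compact set, and regularization is used to keep the coupled scheme numerically stable.
- AI-accelerated iterative solvers: use a network to learn the coarsening operator or the initial guess, which can reduce the number of multilevel iterations and the computational cost while preserving convergence, provided the spectral-radius condition holds.
These three routes give a symbiosis of AI and traditional numerical methods: AI improves performance and extends capability, while the traditional methods supply stability and physical consistency. Future research will focus on theoretical analysis (such as Lyapunov stability), interpretability, and adaptive training frameworks, to further improve the reliability and practical value of hybrid models in high-performance computing.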
6. Complete Code Implementation
I. JAX + Haiku version (jax_diffusive_closure.py)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
1-D Diffusion with Neural Closure Model (Differentiable Physics)
=================================================================
* Uses JAX + Haiku for automatic differentiation.
* Implements a classic FDM solver for the heat equation:
      du/dt = nu * d2u/dx2
  but adds a learned neural closure term C(u) to represent subgrid effects.
* Trains the network against reference data (here the closed-form solution
  for a Gaussian initial condition, standing in for a high-resolution run).
* Demonstrates how to embed an AI component into a traditional HPC workflow.
Author: OpenAI ChatGPT
Date: 2025-08-19
"""
import os
import math
import time
from typing import Tuple
import numpy as np
import jax
import jax.numpy as jnp
import haiku as hk
import optax
# ----------------------------------------------------------------------
# 1. Simulation parameters
# ----------------------------------------------------------------------
NX = 128 # number of grid points (coarse mesh)
NT = 200 # number of time steps
DX = 1.0 / NX # spatial step
DT = 0.001 # temporal step
NU = 0.01 # diffusivity
# ------------------------------------------------------------------------------
# 2. Reference high‑resolution simulation for training data
# ------------------------------------------------------------------------------
def reference_solution(x: jnp.ndarray, t_final: float) -> jnp.ndarray:
    """
    Generate the “ground truth” used for training labels.

    Rather than running a fine-mesh solver, this evaluates the closed-form
    solution of the heat equation for the initial Gaussian bump
    u0(x) = exp(-(x - 0.5)^2 / sigma0^2) with sigma0 = 0.05. The formula is
    exact on the unbounded line and a close approximation on the periodic
    unit interval as long as the bump stays narrow. Bumps with another
    center or width need the same formula with their own parameters.

    Parameters
    ----------
    x : jnp.ndarray
        1‑D spatial coordinates on the coarse grid.
    t_final : float
        Final time at which to evaluate.

    Returns
    -------
    u_ref : jnp.ndarray
        Reference solution on the coarse grid at t_final.
    """
    # Analytical solution: the Gaussian widens and its peak decays
    sigma0 = 0.05
    spread = sigma0**2 + 4 * NU * t_final
    return jnp.sqrt(sigma0**2 / spread) * jnp.exp(-((x - 0.5) ** 2) / spread)
# ----------------------------------------------------------------------
# 3. Differentiable FDM solver with closure term
# ----------------------------------------------------------------------
def fdm_step(u: jnp.ndarray,
             closure_fn,
             nu: float = NU,
             dx: float = DX,
             dt: float = DT) -> jnp.ndarray:
    """
    One explicit Euler step for the 1‑D diffusion equation with a neural
    closure term. The boundary conditions are periodic.

    Parameters
    ----------
    u : jnp.ndarray
        Current field (shape [NX]).
    closure_fn : callable
        Function mapping u (shape [NX]) to the closure term C(u) (shape [NX]).
    nu, dx, dt : float
        Physical and numerical parameters.

    Returns
    -------
    u_next : jnp.ndarray
        Field after one time step.
    """
    # Periodic second derivative (central difference)
    d2u = (jnp.roll(u, -1) - 2 * u + jnp.roll(u, 1)) / dx**2
    # Neural closure term: C(u)
    c_u = closure_fn(u)
    # Explicit Euler update
    return u + dt * (nu * d2u + c_u)
def run_simulation(init_u: jnp.ndarray,
                   closure_params: hk.Params,
                   closure_state: hk.State,
                   num_steps: int) -> Tuple[jnp.ndarray, hk.State]:
    """
    Run the simulation for `num_steps` steps and return the final field.

    Parameters
    ----------
    init_u : jnp.ndarray
        Initial condition.
    closure_params : hk.Params
        Trained parameters of the closure network.
    closure_state : hk.State
        Any state (none for the stateless MLP; passed through unchanged).
    num_steps : int
        Number of time steps to evolve.

    Returns
    -------
    u_final, new_state
    """
    # Bind the parameters so fdm_step sees a plain function u -> C(u)
    def closure_fn(u):
        return closure_transformed.apply(closure_params, u)

    def step(u, _):
        return fdm_step(u, closure_fn), None

    # lax.scan keeps the time loop compact under jit and reverse-mode autodiff
    u_final, _ = jax.lax.scan(step, init_u, None, length=num_steps)
    return u_final, closure_state


# ----------------------------------------------------------------------
# 4. Neural closure model (simple MLP)
# ----------------------------------------------------------------------
class ClosureNetwork(hk.Module):
    """MLP that maps the local field to a scalar closure term."""

    def __init__(self,
                 hidden_sizes: Tuple[int, ...] = (32, 32),
                 name=None):
        super().__init__(name=name)
        self.hidden_sizes = hidden_sizes

    def __call__(self, u: jnp.ndarray) -> jnp.ndarray:
        # Treat every grid point as one sample with a single feature
        x = u[:, None]                  # shape (NX, 1)
        for h in self.hidden_sizes:
            x = hk.Linear(h)(x)
            x = jax.nn.relu(x)
        out = hk.Linear(1)(x)           # scalar output per grid point
        return out.squeeze(-1)          # shape (NX,)


def _closure_forward(u: jnp.ndarray) -> jnp.ndarray:
    # Haiku modules must be constructed inside the transformed function
    return ClosureNetwork(hidden_sizes=(64, 64))(u)


# Haiku transform for parameter management: init(rng, u) and apply(params, u)
closure_transformed = hk.without_apply_rng(hk.transform(_closure_forward))
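Because `fdm_step` uses explicit Euler, the simulation parameters have to respect the usual diffusion stability limit $\nu\,\Delta t / \Delta x^2 \le 1/2$. The short check below uses the script's values of `NX`, `DT` and `NU`. It covers only the diffusion part of the update; the learned closure term is outside this linear analysis.

```python
import numpy as np

NX, DT, NU = 128, 0.001, 0.01
DX = 1.0 / NX

# Diffusion number of the explicit scheme
r = NU * DT / DX**2

# Von Neumann amplification factor for each discrete Fourier mode k
k = np.arange(NX)
amplification = 1.0 - 4.0 * r * np.sin(np.pi * k / NX) ** 2

print(f"diffusion number r = {r:.5f} (limit 0.5)")
print(f"max |G| = {np.abs(amplification).max():.5f}")
```

With these values the scheme sits comfortably inside the stable region, so any blow-up during training would come from the closure term rather than from the base solver.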
Training loop
# ----------------------------------------------------------------------
# 5. Training routine
# ----------------------------------------------------------------------
def loss_fn(params: hk.Params,
            state: hk.State,
            init_u: jnp.ndarray,
            target_u: jnp.ndarray) -> jnp.ndarray:
    """
    Compute L2 loss between simulation output and reference solution.
    """
    u_sim, _ = run_simulation(init_u, params, state, NT)
    return jnp.mean((u_sim - target_u)**2)


@jax.jit
def update(params: hk.Params,
           opt_state: optax.OptState,
           init_u: jnp.ndarray,
           target_u: jnp.ndarray) -> Tuple[hk.Params, optax.OptState]:
    grads = jax.grad(loss_fn)(params, None, init_u, target_u)
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state
def train(num_epochs: int = 2000,
          batch_size: int = 32) -> hk.Params:
    # Random seed
    rng = jax.random.PRNGKey(42)
    # Initialize parameters on a dummy input
    sample_input = jnp.zeros((NX,))
    params = closure_transformed.init(rng, sample_input)
    state = None
    # Optimizer
    global optimizer  # read by update(), so it must exist before the first call
    optimizer = optax.adamw(learning_rate=1e-3)
    opt_state = optimizer.init(params)
    t_final = DT * NT
    # Training loop
    for epoch in range(num_epochs):
        # Separate keys so that centers and widths are drawn independently
        rng, center_key, width_key = jax.random.split(rng, 3)
        # Sample random initial conditions (Gaussian bumps)
        centers = jax.random.uniform(center_key, shape=(batch_size,), minval=0.1, maxval=0.9)
        widths = jax.random.uniform(width_key, shape=(batch_size,), minval=0.01, maxval=0.05)
        init_conds = []
        targets = []
        for c, w in zip(centers, widths):
            u0 = jnp.exp(-((x_grid - c) ** 2) / (w**2))
            # Normalize to [0,1]
            u0 = (u0 - u0.min()) / (u0.max() - u0.min())
            init_conds.append(u0)
            # Same closed form as reference_solution, evaluated with this
            # sample's own center and width so the label matches the input
            spread = w**2 + 4 * NU * t_final
            targets.append(jnp.sqrt(w**2 / spread)
                           * jnp.exp(-((x_grid - c) ** 2) / spread))
        init_batch = jnp.stack(init_conds)   # shape [B,NX]
        target_batch = jnp.stack(targets)
        # For simplicity we train on one sample at a time
        params, opt_state = update(params, opt_state,
                                   init_batch[0], target_batch[0])
        if epoch % 200 == 0:
            l = loss_fn(params, None, init_batch[0], target_batch[0])
            print(f"Epoch {epoch:04d} | Loss {l:.6e}")
    return params
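One detail of the sampling code is easy to get wrong in JAX: random keys are explicit values, so reusing a key reproduces the same underlying draw. The small demonstration below is independent of the training script, and its variable names are illustrative.

```python
import jax

key = jax.random.PRNGKey(42)
key, center_key, width_key = jax.random.split(key, 3)

# Reusing one key gives identical samples
a = jax.random.uniform(center_key, shape=(4,))
b_same = jax.random.uniform(center_key, shape=(4,))

# A different key from the same split gives an independent draw
b_split = jax.random.uniform(width_key, shape=(4,))

print(a, b_same, b_split)
```

When one key feeds two `uniform` calls that differ only in `minval` and `maxval`, the second result is just an affine transform of the first. Centers and widths would then be perfectly correlated and the training set would cover far fewer combinations than intended.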
Main entry point
if __name__ == "__main__":
    # Spatial grid on the coarse mesh (periodic domain [0,1))
    x_grid = jnp.linspace(0.0, 1.0 - DX, NX)
    print("Training neural closure model...")
    trained_params = train(num_epochs=2000, batch_size=32)
    # Evaluate on a new initial condition
    rng = jax.random.PRNGKey(123)
    c_new = jax.random.uniform(rng, minval=0.2, maxval=0.8)
    w_new = 0.03
    u0_new = jnp.exp(-((x_grid - c_new) ** 2) / (w_new**2))
    u0_new = (u0_new - u0_new.min()) / (u0_new.max() - u0_new.min())
    # Reference solution: same closed form as reference_solution, evaluated
    # with this bump's own center and width
    spread = w_new**2 + 4 * NU * DT * NT
    u_ref = jnp.sqrt(w_new**2 / spread) * jnp.exp(-((x_grid - c_new) ** 2) / spread)
    # Simulate with trained closure
    u_pred, _ = run_simulation(u0_new, trained_params, None, NT)
    print("Final RMS error (trained closure):",
          jnp.linalg.norm(u_pred - u_ref) / jnp.sqrt(NX))
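As a sanity check on the training labels, the sketch below runs the plain explicit scheme (no closure) with the script's parameters and compares it against two candidate closed forms for a bump centered at 0.5 with width 0.05. It is a standalone NumPy check; the variable names are illustrative.

```python
import numpy as np

NX, DT, NU, NT = 128, 0.001, 0.01, 200
DX = 1.0 / NX
x = np.arange(NX) * DX
c, w = 0.5, 0.05
t = DT * NT

# Plain explicit Euler diffusion, periodic boundaries, no closure term
u = np.exp(-((x - c) ** 2) / w**2)
for _ in range(NT):
    u = u + DT * NU * (np.roll(u, -1) - 2.0 * u + np.roll(u, 1)) / DX**2

spread = w**2 + 4.0 * NU * t
# Closed form with the amplitude decay of a spreading Gaussian
u_exact = np.sqrt(w**2 / spread) * np.exp(-((x - c) ** 2) / spread)
# Same shape but with the peak held at 1
u_no_decay = np.exp(-((x - c) ** 2) / spread)

err_exact = np.max(np.abs(u - u_exact))
err_no_decay = np.max(np.abs(u - u_no_decay))
print("error vs decaying Gaussian:", err_exact)
print("error vs fixed-peak Gaussian:", err_no_decay)
```

The coarse solver already tracks the decaying Gaussian closely. With correct labels the closure therefore has little left to learn on this problem, and a label without the decay factor would train the network to cancel a modelling error instead.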
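II. Taichi version (taichi_example.py)
The Taichi API differs somewhat from JAX, but the idea is the same. The example below shows how to implement a differentiable physics simulator in Taichi (it requires Taichi 1.6 or later with automatic differentiation enabled).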
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Taichi implementation of the same 1‑D diffusion + neural closure idea.
Uses Taichi's automatic differentiation (autodiff) to compute gradients.
Author: OpenAI ChatGPT
Date: 2025-08-19
"""
import taichi as ti
import numpy as np

ti.init(arch=ti.cpu, default_fp=ti.f32)

NX = 128
DT = 0.001
DX = 1.0 / NX
NU = 0.01
NT = 200
HIDDEN = 64              # width of the two hidden layers
LR = 1e-3                # learning rate
INV_DX2 = 1.0 / DX**2
x_grid = (np.arange(NX) * DX).astype(np.float32)

# ----------------------------------------------------------------------
# Fields. Autodiff needs needs_grad=True, and no element may be overwritten
# within one forward pass, so every time step gets its own slot.
# ----------------------------------------------------------------------
u = ti.field(ti.f32, shape=(NT + 1, NX), needs_grad=True)
a1 = ti.field(ti.f32, shape=(NT, NX, HIDDEN), needs_grad=True)
z2 = ti.field(ti.f32, shape=(NT, NX, HIDDEN), needs_grad=True)
a2 = ti.field(ti.f32, shape=(NT, NX, HIDDEN), needs_grad=True)
c_u = ti.field(ti.f32, shape=(NT, NX), needs_grad=True)
target = ti.field(ti.f32, shape=NX)
loss = ti.field(ti.f32, shape=(), needs_grad=True)
# Neural closure: MLP 1 -> HIDDEN -> HIDDEN -> 1, applied per grid point
W1 = ti.field(ti.f32, shape=HIDDEN, needs_grad=True)
b1 = ti.field(ti.f32, shape=HIDDEN, needs_grad=True)
W2 = ti.field(ti.f32, shape=(HIDDEN, HIDDEN), needs_grad=True)
b2 = ti.field(ti.f32, shape=HIDDEN, needs_grad=True)
W3 = ti.field(ti.f32, shape=HIDDEN, needs_grad=True)
b3 = ti.field(ti.f32, shape=(), needs_grad=True)


def xavier(fan_in, fan_out, shape):
    # Xavier init; biases keep Taichi's default of zero
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return np.random.uniform(-limit, limit, size=shape).astype(np.float32)


W1.from_numpy(xavier(1, HIDDEN, (HIDDEN,)))
W2.from_numpy(xavier(HIDDEN, HIDDEN, (HIDDEN, HIDDEN)))
W3.from_numpy(xavier(HIDDEN, 1, (HIDDEN,)))


# ----------------------------------------------------------------------
# Closure MLP, one kernel per layer (sums use atomic accumulation)
# ----------------------------------------------------------------------
@ti.kernel
def hidden1(t: ti.i32):
    for i, j in ti.ndrange(NX, HIDDEN):
        a1[t, i, j] = ti.max(W1[j] * u[t, i] + b1[j], 0.0)   # ReLU


@ti.kernel
def hidden2_accumulate(t: ti.i32):
    for i, j, k in ti.ndrange(NX, HIDDEN, HIDDEN):
        z2[t, i, j] += W2[j, k] * a1[t, i, k]


@ti.kernel
def hidden2_activate(t: ti.i32):
    for i, j in ti.ndrange(NX, HIDDEN):
        a2[t, i, j] = ti.max(z2[t, i, j] + b2[j], 0.0)       # ReLU


@ti.kernel
def closure_out(t: ti.i32):
    for i, j in ti.ndrange(NX, HIDDEN):
        c_u[t, i] += W3[j] * a2[t, i, j]


# ----------------------------------------------------------------------
# Diffusion solver with closure
# ----------------------------------------------------------------------
@ti.kernel
def fdm_step(t: ti.i32):
    for i in range(NX):
        ip = (i + 1) % NX
        im = (i - 1 + NX) % NX
        d2u = (u[t, ip] - 2.0 * u[t, i] + u[t, im]) * INV_DX2
        u[t + 1, i] = u[t, i] + DT * (NU * d2u + c_u[t, i] + b3[None])


@ti.kernel
def compute_loss():
    for i in range(NX):
        loss[None] += (u[NT, i] - target[i]) ** 2 / NX


@ti.kernel
def load_sample(u0: ti.types.ndarray(), tgt: ti.types.ndarray()):
    for i in range(NX):
        u[0, i] = u0[i]
        target[i] = tgt[i]


# ----------------------------------------------------------------------
# Training (plain SGD written by hand; Taichi ships no built-in optimizer)
# ----------------------------------------------------------------------
@ti.kernel
def sgd_step():
    for j in range(HIDDEN):
        W1[j] -= LR * W1.grad[j]
        b1[j] -= LR * b1.grad[j]
        b2[j] -= LR * b2.grad[j]
        W3[j] -= LR * W3.grad[j]
    for j, k in W2:
        W2[j, k] -= LR * W2.grad[j, k]
    b3[None] -= LR * b3.grad[None]


def train_step() -> float:
    # Sample a random Gaussian bump and its closed-form diffused target
    c = np.random.uniform(0.1, 0.9)
    w = np.random.uniform(0.01, 0.05)
    spread = w**2 + 4.0 * NU * DT * NT
    u0 = np.exp(-((x_grid - c) ** 2) / w**2).astype(np.float32)
    tgt = (np.sqrt(w**2 / spread)
           * np.exp(-((x_grid - c) ** 2) / spread)).astype(np.float32)
    load_sample(u0, tgt)
    # Accumulators must start from zero on every pass
    z2.fill(0.0)
    c_u.fill(0.0)
    loss[None] = 0.0
    # Record the forward pass; gradients are computed when the tape closes
    with ti.ad.Tape(loss=loss):
        for t in range(NT):
            hidden1(t)
            hidden2_accumulate(t)
            hidden2_activate(t)
            closure_out(t)
            fdm_step(t)
        compute_loss()
    sgd_step()
    return loss[None]


if __name__ == "__main__":
    for epoch in range(100):
        value = train_step()
        if epoch % 10 == 0:
            print(f"Epoch {epoch:04d} | Loss {value:.6e}")