Skip to content

张量神经网络 TNN

参考 Tensor Neural Network and Its Numerical Integration

特点:

  • 网络输出是函数变量分离的形式
  • 将损失函数中高维积分的计算转化成一维积分的计算,极大缓解了维度灾难
  • 适用于高维弱形式问题

弱形式偏微分方程求解

考虑如下 poisson equation:

Δu=ex(x2+y3+6y),x[0,1]2u(0,y)=y3u(1,y)=(1+y3)e1u(x,0)=xexu(x,1)=ex(x+1)

其真实解为 u(x,y)=ex(x+y3)

现在考虑使用 TNN 求解这个问题。

步骤 1 导入模型

导入必要的库:

python
import torch
import nets
import myfuntools as mft
from nets.tnn import Tnnfn
mdl = nets.TNN([2,20,20,1])

该语句可以方便地构建一个 TNN,其中 [2,20,20,1] 代表:

  • 2:输入神经元的数量为 2,对应问题的自变量 x,y
  • 20:第一个隐藏层的神经元数量为 20
  • 20:第二个隐藏层的神经元数量为 20
  • 1:输出神经元的数量为 1,对应问题的 u(x,y)

步骤 2 翻译

将求解问题翻译成代码。不同于 PINN, TNN 所考虑的是问题的弱形式,对于一个 poisson equation:

{Δu(x)+b(x)u(x)=f(x),xΩ,u(x)=0,xΩ.

其弱形式为:

Ωuvdx+Ωb(x)uvdx=Ωfvdx.

求解弱形式时,我们会将神经网络输出层的前一层看作方程解的基函数,于是上面对应一个线性方程组,我们记作:

Ac=F

其中 c 是网络的最后一层参数,也是我们要求解的参数。然而,我们的问题并非是一个齐次的 poisson equation,我们的做法是做一个变量变换,将其变为齐次的方程。假设神经网络所表达的基函数满足齐次边界条件,一个已知函数 A 满足非齐次边界条件,那么我们可是使用下面的表达式来作近似解:

u(x)=TNN(x)+A(x)

于是

{ΔTNN(x)=f+ΔA,xΩ,TNN(x)=0,xΩ.

这样非齐次边界问题就被转化为了齐次边界问题。

在本例中,我们的边界是区域 [0,1]2 的边界,我们可以对 MLP 层的输出乘以一个边界函数 x(1x) 迫使 TNN 在边界上为 0:

python
mdl.set_bcfn(lambda x: x*(1-x))

# 或者在定义的时候直接指定
# mdl = nets.TNN([2,20,20,1],bcfn=lambda x: x*(1-x))

接下来我们就要构造一个函数 A,使其恰好满足非齐次边界条件,并且该函数还要是 TNN 形式的,即可变量分离的形式。

python
def get_f(X):
    x, y = X.T.unsqueeze(-1)
    d1 = torch.cat([-torch.exp(-x) * (x - 2), -torch.exp(-x)], 1)
    d2 = torch.cat([torch.ones_like(y), y**3 + 6 * y], 1)
    return torch.stack([d1, d2], dim=1)  # [N,D,P]=[N,2,2]

def u0y(y):
    return y**3

def u1y(y):
    return (1 + y**3) / torch.e

def ux0(x):
    return x * torch.exp(-x)

def ux1(x):
    return (1 + x) * torch.exp(-x)

bc = u0y, u1y, ux0, ux1

我们将每个条件涉及的函数都写成了 TNN 的形式,对于二维问题的"四个"边界条件,下面的方法可以构造出符合要求的函数 A,该函数一定满足边界条件且可以变量分离:

python
Afn, ddAfn = mdl.get_Afn_and_ddAfn_by_bc(*bc)

其中函数 Afn 满足边界条件,具有 TNN 形式(不是一个 TNN, 只是一个普通 python 函数),ddAfn 代表 ΔA

方程的右端源项是 f+ΔA, 在 TNN 中,两个 TNN 形式的函数并非是最终的输出结果直接相加,而是 TNN 形式数据的拼接。我们有两种方法可以做到这一点,一种是直接拼接, 另一种是 TNN 加法:

python
f = Tnnfn(get_f,P=2)
dA = Tnnfn(ddAfn,P=4)
f_add_dA = f + dA

使用 Tnnfn 包裹的函数做加减乘幂时,会被转化为 TNN 加减乘幂的相关运算。现在我们已经准备好了条件:

  • mdl, 满足齐次边界条件的 TNN
  • f_add_dA, 右端源项函数

步骤 3 配置单

接下来我们要给出一个配置单(sets), 我们的训练函数是 pde_solve,当我们调出这个函数时,编辑器会自动弹出一份文档,里面有可参考的配置信息,我们也可以直接复制:

python
x1, w1 = mft.composite_quadrature_1d(15, 0, 1, 1)
x2, w2 = mft.composite_quadrature_1d(15, 0, 1, 1)
X = torch.cat([x1,x2],1) # [N,D]
W = torch.cat([w1,w2],1)
Xg = X.clone().requires_grad_(True)
sets = {
    "x": Xg,
    "w": W,
    "f": f_add_dA.tnn(X),
    "optimizer": torch.optim.LBFGS,
    "lr": 1,
    "epochs": 100,
}

这里的 X, W 分别是弱形式问题对应的积分点和积分权重。f_add_dA.tnn(X) 给出的是 TNN 形式的数据(一个形状为 [N,D,P] 的 tensor),如果是 f_add_dA(X) 则给出的是函数直接的输出,即下面是成立的:

python
assert torch.allclose(f_add_dA(X), f(X) + dA(X))
assert torch.allclose(f_add_dA.tnn(X), torch.cat([f.tnn(X),dA.tnn(X)],dim=-1))
assert torch.allclose(f_add_dA.tnn(X), torch.cat([get_f(X),ddAfn(X)],dim=-1))

步骤 4 训练

将配置单传入 pde_solve 进行训练:

python
mdl.pde_solve(**sets3)

通常,对于数据量不大的问题,只需要数秒就能完成训练,本例在 i5 12 代 CPU 上花费 2s 左右。

步骤 5 可视化

可视化。本库本身不提供可视化的功能,我们可以使用 matplotlib 按需实现自己的可视化,下面是一个示例:

python
from matplotlib.colors import Normalize
import matplotlib.pyplot as plt

def myplot(mdl):
    norm = Normalize(vmin=0, vmax=1)
    t1 = torch.linspace(0, 1, 90).view(1, -1)
    t2 = torch.linspace(0, 1, 100).view(-1, 1)
    acc = torch.exp(-t1) * (t1 + t2**3)

    shape = (t1 + t2).shape
    p1 = t1.broadcast_to(shape)
    p2 = t2.broadcast_to(shape)
    pe = torch.cat((p1.flatten().view(-1, 1), p2.flatten().view(-1, 1)), dim=1)
    pre:Tensor = mdl(pe)

    fig = plt.figure(figsize=(12, 3))
    ax1 = fig.add_subplot(1, 3, 1)
    im1 = ax1.imshow(acc, extent=[0, 1, 0, 1], origin="lower", norm=norm)
    plt.colorbar(im1)

    ax2 = fig.add_subplot(1, 3, 2)
    im2 = ax2.imshow(pre.view_as(acc), extent=[0, 1, 0, 1], origin="lower", norm=norm)
    plt.colorbar(im2)

    ax3 = fig.add_subplot(1, 3, 3)
    im3 = ax3.imshow(abs(pre.view_as(acc) - acc), extent=[0, 1, 0, 1], origin="lower")
    print("最大绝对误差:",torch.max(abs(pre.view_as(acc) - acc)))
    plt.colorbar(im3)
    plt.tight_layout()

def predict(x):
    mdl.to('cpu')
    return mdl(x) + mdl.get_output_tnn(torch.ones(1,4),Afn(x))

with torch.device('cpu'):
    with torch.no_grad():
        myplot(predict)

从左到右依次是精确值、预测值、逐点绝对误差。

这里值得注意的是,TNN (也就是我们的 mdl) 的输出并非真实值的近似解,因为我们做过变量变换,真实值的近似解为

u(x)=TNN(x)+A(x)

A(x) 的输出我们可以使用

python
mdl.get_output_tnn(torch.ones(1,4),Afn(x))

获得,Afn(x) 只提供 TNN 形式的数据(形状为 [N,D,P]=[N,2,4]),get_output_tnn 是不依赖于实例 mdl 的方法,提供基函数的系数和 TNN 形式的数据就可以获得最终的输出,即下面是成立的:

python
assert torch.allclose(Tnnfn(Afn,P=4)(X),mdl.get_output_tnn(torch.ones(1,4),Afn(X)))
assert torch.allclose(Tnnfn(Afn,P=4)(X),nets.TNN.get_output_tnn(torch.ones(1,4),Afn(X)))

TIP

该例的精度可以继续提升. 在创建 TNN 时,可以指定 method:

python
mdl = nets.TNN([2,20,20,1],method="mix")

method 会影响神经网络参数的求解方式,有三种取值,ritz(默认),mix 和 strong:

  • ritz, 采用 ritz 形式的损失,参数由配置单中优化器求解(本例约 2s)
  • mix, 隐层采用优化器求解,最后一层采用 Garlerkin 求解(通常精度最高,速度也最快,本例约 0.4s)
  • strong, 当求解强形式时显式指定。

采用 mix 方法求得的结果:

TIP

TNN 有两种形式的输出,一种是网络的直接输出,输出特征为一维。另一种是 TNN 形式的数据,即形状为 [N,D,P] 的数据,

  • N 积分点数量
  • D 输入特征维度
  • P 单个 MLP 的输出特征维度

[N,D,P] 形式的数据是进行积分运算的关键,许多内置函数方法直接接收 [N,D,P] 形式的数据而非 TNN 的最终输出。我们提供了 Tnnfn, 将函数包括在其中可以方便的进行 TNN 意义下的加减乘幂。

假设 mdl 是张量神经网络,则下面是成立的

python
assert torch.allclose(Tnnfn(mdl).tnn(X), mdl.block[0](X))
assert torch.allclose(Tnnfn(mdl)(X), mdl(X))