1. 引言
1.1.卷积神经网络CNN
卷积神经网络(CNN)的数学模型是深度学习中用于处理图像和其他高维数据的关键组成部分。那么,CNN究竟是什么呢?
总结起来,CNN网络主要完成以下操作:
-
卷积操作(Convolution Operation):
卷积是CNN中的基本操作,用于提取图像特征。给定输入图像 ( I ) 和卷积核 ( K ),卷积操作可以表示为:
( I ∗ K ) ( x , y ) = ∑ m ∑ n I ( x − m , y − n ) ⋅ K ( m , n ) (I * K)(x, y) = \sum_m \sum_n I(x - m, y - n) \cdot K(m, n) (I∗K)(x,y)=∑m∑nI(x−m,y−n)⋅K(m,n)
其中 ( x , y ) (x, y) (x,y)表示输出特征图的位置, ( m , n ) (m, n) (m,n)是卷积核的滑动窗口位置。 -
卷积核(Convolutional Kernel):
卷积核是一个小的矩阵,用于在输入数据上滑动以产生特征图。每个卷积核可以学习到输入数据中的特定特征,如边缘、角点或更复杂的纹理模式。 -
填充(Padding):
填充是在输入图像的边缘添加零或特定值的操作,以控制卷积后输出特征图的大小。填充可以表示为:
P a d d e d ( I ) ( x , y ) = { I ( x − p , y − p ) if x − p ≥ 0 and y − p ≥ 0 0 otherwise Padded(I)(x, y) = \begin{cases} I(x - p, y - p) & \text{if } x - p \geq 0 \text{ and } y - p \geq 0 \\ 0 & \text{otherwise} \end{cases} Padded(I)(x,y)={I(x−p,y−p)0if x−p≥0 and y−p≥0otherwise
其中 p p p是填充的大小。 -
步长(Stride):
步长定义了卷积核在输入数据上滑动的间隔。步长为 s s s的卷积操作可以表示为:
( I ∗ K ) ( x , y ) = ∑ m ∑ n I ( ( x − m ) s , ( y − n ) s ) ⋅ K ( m , n ) (I * K)(x, y) = \sum_m \sum_n I((x - m)s, (y - n)s) \cdot K(m, n) (I∗K)(x,y)=∑m∑nI((x−m)s,(y−n)s)⋅K(m,n) -
激活函数(Activation Function):
激活函数用于在卷积层之后引入非线性,使得网络能够学习复杂的模式。ReLU(Rectified Linear Unit)是最常用的激活函数之一,定义为:
ReLU ( z ) = max ( 0 , z ) \text{ReLU}(z) = \max(0, z) ReLU(z)=max(0,z) -
池化(Pooling):
池化操作用于降低特征图的空间维度,通常使用最大池化或平均池化。最大池化可以表示为:
Max Pooling ( I ) ( x , y ) = max m , n ∈ window I ( ( x − m ) s , ( y − n ) s ) \text{Max Pooling}(I)(x, y) = \max_{m, n \in \text{window}} I((x - m)s, (y - n)s) Max Pooling(I)(x,y)=maxm,n∈windowI((x−m)s,(y−n)s) -
全连接层(Fully Connected Layer):
全连接层是CNN中的密集层,其中每个节点都与前一层的所有激活值相连。数学上,全连接层的输出可以表示为:
O = f ( W ⋅ A + b ) O = f(W \cdot A + b) O=f(W⋅A+b)
其中 O O O是输出, W W W是权重矩阵, A A A是前一层的激活值, b b b是偏置项,$f $ 是激活函数。 -
损失函数(Loss Function):
损失函数用于衡量模型预测与实际标签之间的差异。对于分类问题,交叉熵损失是常用的损失函数:
L = − ∑ c = 1 C y o , c log ( p o , c ) L = -\sum_{c=1}^C y_{o,c} \log(p_{o,c}) L=−∑c=1Cyo,clog(po,c)
其中 C C C是类别数, y o , c y_{o,c} yo,c是真实标签的one-hot编码, p o , c p_{o,c} po,c是模型预测的概率。 -
反向传播(Backpropagation):
反向传播是一种算法,用于计算损失函数关于网络参数的梯度,并更新这些参数以最小化损失。这个过程涉及到链式法则的应用。 -
优化器(Optimizer):
优化器定义了参数更新的规则。例如,SGD(随机梯度下降)的更新规则可以表示为:
W : = W − η ⋅ ∇ W L W := W - \eta \cdot \nabla_W L W:=W−η⋅∇WL
其中 η \eta η是学习率,$\nabla_W L $是损失相对于权重 W W W的梯度。
这些数学模型和操作构成了CNN的基础,使得它们能够从原始数据中自动学习特征并进行有效的模式识别。随着研究的发展,还有许多变体和改进,如深度可分离卷积、空洞卷积等,进一步增强了CNN的能力。
通过上面的讨论,我们将CNN定义:
R 2 \mathbb{R}^2 R2对信号 f : R 2 ⟶ R and K : R 2 ⟶ R at x ∈ R 2 \mathbb{f}:\mathbb{R}^2\longrightarrow\mathbb{R}\text{ and }\text{K}:\mathbb{R}^2\longrightarrow\mathbb{R}\text{ at }\mathbb{x}\in\mathbb{R}^2 f:R2⟶R and K:R2⟶R at x∈R2
( f ∗ k ) ( x ) = ∫ R 2 f ( x ~ ) k ( x ~ − x ) d x ~ , (f * k) (\mathbf{x}) = \int_{\mathbb{R}^2} f(\tilde{\mathbf{x}})k(\tilde{\mathbf{x}} - \mathbf{x}) \text{d}\tilde{\mathbf{x}}, (f∗k)(x)=∫R2f(x~)k(x~−x)dx~,
我们可以观察到,卷积操作实质上就是函数 f f f与其移动(或滑动)后的核函数 k k k之间的内积。
需要注意的是,在实际应用中,卷积神经网络(CNNs)实现的是这一操作的离散化版本。
( f ∗ k ) ( x ) = ∑ x ~ ∈ Z 2 f ( x ~ ) k ( x − x ~ ) Δ x ~ = ∑ x ~ ∈ Z 2 f ( x ~ ) k ( x − x ~ ) \begin{align} (f*k) (\mathbf{x}) &= \sum_{\mathbf{\tilde{x}} \in \mathbb{Z}^2} f(\mathbf{\tilde{x}})k(\mathbf{x}-\mathbf{\tilde{x}})\Delta\mathbf{\tilde{x}}\\ &= \sum_{\mathbf{\tilde{x}} \in \mathbb{Z}^2} f(\mathbf{\tilde{x}})k(\mathbf{x}-\mathbf{\tilde{x}}) \end{align} (f∗k)(x)=x~∈Z2∑f(x~)k(x−x~)Δx~=x~∈Z2∑f(x~)k(x−x~)
考虑到图像中像素的排列通常是等距的,我们将间隔 Δx 设置为1,这样公式 ( 1 ) (1) (1)就可以简化为公式(2)。为了简化理解,我们在本次回顾中采用连续域的概念进行说明。
在卷积神经网络的卷积层,例如PyTorch框架中的Conv2D层,卷积操作会在图像定义的范围内对每个位置执行。由于整个输入图像共享同一套卷积核权重,因此卷积层的输出对于图像的平移变换具有不变性。此外,输入图像通常由多个通道组成,这些通道在卷积过程中会被统一考虑。
本文将利用PyTorch提供的torch.nn.functional.conv2d()函数来自动完成在每个特征点上的卷积积分操作。这样做的好处是,我们无需手动编写卷积操作的代码,从而简化了开发过程。通过使用这个函数,我们可以轻松地在输入特征图的每个位置上应用卷积核,并自动处理多通道数据的求和,实现特征提取的自动化。
1.2.图卷积网络GCNN
图卷积网络(Graph Convolutional Networks, GCNNs)的数学表达通常依赖于图论的概念,特别是图上的信号处理。
-
图定义:
图 G G G 由节点集合 V V V 和边集合 E E E 组成。对于加权图,每条边 e i j e_{ij} eij 都有一个权重 w i j w_{ij} wij。 -
邻接矩阵 A A A:
邻接矩阵 A A A 是一个 N × N N \times N N×N 的矩阵,其中 N N N 是图中节点的数量。如果节点 i i i 和 j j j 之间有边,则 A i j = w i j A_{ij} = w_{ij} Aij=wij,否则 A i j = 0 A_{ij} = 0 Aij=0。 -
度矩阵 D D D:
度矩阵 D D D 是一个对角矩阵,其中 D i i D_{ii} Dii 是节点 i i i 的度,即与节点 i i i 相连的边的权重之和。 -
归一化邻接矩阵 A ^ \hat{A} A^:
归一化邻接矩阵 A ^ \hat{A} A^ 是通过 A A A 除以其对应节点的度来计算的,即
A ^ = D − 1 2 A D − 1 2 \hat{A} = D^{-\frac{1}{2}}AD^{-\frac{1}{2}} A^=D−21AD−21。 -
图卷积操作:
图卷积操作通常定义为节点特征 X X X 和可学习的权重矩阵 K K K 之间的卷积,可以表示为 Y = σ ( A ^ X K ) Y = \sigma(\hat{A}XK) Y=σ(A^XK),其中 σ \sigma σ 是一个非线性激活函数。 -
节点特征 X X X:
节点特征 X X X 是一个 N × F N \times F N×F 的矩阵,其中 F F F 是每个节点的特征维度。 -
权重矩阵 K K K:
权重矩阵 K K K 是一个 F × F ′ F \times F' F×F′ 的矩阵,其中 F ′ F' F′ 是卷积后的特征维度。 -
图卷积层:
图卷积层可以堆叠,每一层都可能有不同的权重矩阵,并且可以学习不同尺度的特征。 -
池化操作:
图池化操作用于减少节点的数量,同时保持重要的结构特征。 -
等变性:
GCNNs设计为对特定类型的图变换保持等变性,例如子图同构性或图的平移等。
GCNNs的关键在于它们能够在图结构数据上进行卷积操作,捕捉节点间的局部连接模式,同时保持对图结构的尊重。这使得GCNNs非常适合处理社交网络、分子结构、交通网络等图结构数据。
根据上面的讨论,我们对GCNN进行如下定义:
我们将对二维图像进行操作,这些图像通常定义在 R 2 \mathbb{R}^2 R2上。因此,在构建一个能够跟踪输入中特征在何种姿态(即:来自群 G = R 2 ⋊ H G=\mathbb{R}^2\rtimes H G=R2⋊H的变换)下出现的网络的第一步,我们需要将信号转移到另一个域,在这个域中,不同姿态下的相同特征是解耦的。它通过提升卷积实现,它将输入信号 f i n : R 2 → R f_{in}:\mathbb{R}^2\rightarrow \mathbb{R} fin:R2→R中的特征映射到群 G G G 上的特征映射 f o u t f_{out} fout: G → R \text{G}\rightarrow \mathbb{R} G→R。对于在 R 2 \mathbb{R}^2 R2 上定义的信号 f f f 和核 k k k,以及群元素 g = ( x , h ) ∈ G = R 2 ⋊ H g=(\boldsymbol{x}, h) \in G=\mathbb{R}^2 \rtimes H g=(x,h)∈G=R2⋊H:
( f ∗ lifting k ) ( g ) = ∫ R 2 f ( x ~ ) k h ( x ~ − x ) d x ~ . (f *_{\text{lifting}} k) (g) = \int_{\mathbb{R}^2} f(\tilde{\mathbf{x}})k_h(\tilde{\mathbf{x}} - \mathbf{x}) \,{\rm d}\tilde{\mathbf{x}}. (f∗liftingk)(g)=∫R2f(x~)kh(x~−x)dx~.
其中, k h k_h kh 是核 k : R 2 → R k:\mathbb{R}^2 \rightarrow \mathbb{R} k:R2→R 在群元素 h ∈ H h \in H h∈H 的正则表示 L h \mathcal{L}_h Lh 下的变换; k h = 1 ∣ h ∣ L h [ k ] k_h = \frac{1}{|h|}\mathcal{L}_{h}[k] kh=∣h∣1Lh[k];
因子 1 ∣ h ∣ \frac{1}{|h|} ∣h∣1,其中 ∣ h ∣ |h| ∣h∣ 是群元素 h h h 在 R 2 \mathbb{R}^2 R2 中的矩阵表示的行列式,它考虑了 h h h 可能在 R 2 \mathbb{R}^2 R2 上引起的体积变化。当我们处理循环群时,我们不会遇到这个问题(旋转矩阵的行列式为1, R 2 \mathbb{R}^2 R2 上的旋转不改变体积),但如果你想要实现例如膨胀群的等变性,这就变得重要了。
接下来,既然我们已经在群上定义了特征映射 f o u t : G → R f_{out}:G\rightarrow \mathbb{R} fout:G→R,我们就应用群卷积,将卷积操作扩展到整个群 G G G 上的积分;
( f ∗ g r o u p k ) ( g ) = ∫ G f ( g ~ ) k ( g − 1 ⋅ g ~ ) d g ~ = ∫ R 2 ∫ H f ( x ~ , h ~ ) L x L h k ( x ~ , h ~ ) 1 ∣ h ∣ d x ~ d h ~ = ∫ R 2 ∫ H f ( x ~ , h ~ ) k ( h − 1 ( x ~ − x ) , h − 1 ⋅ h ~ ) 1 ∣ h ∣ d x ~ d h ~ . \begin{aligned} (f *_{\mathrm{group}} k) (g) &=\int_G f(\tilde{g})k(g^{-1} \cdot \tilde{g}) {\rm d}\tilde{g} \\ &=\int_{\mathbb{R}^2}\int_H f(\tilde{\mathbf{x}}, \tilde{h})\mathcal{L}_{x}\mathcal{L}_{h}k(\tilde{\mathbf{x}}, \tilde{h})\dfrac{1}{|h|} \,{\rm d}\mathbf{\tilde{x}}\,{\rm d}\tilde{h}\\ &=\int_{\mathbb{R}^2}\int_H f(\tilde{\mathbf{x}},\tilde{h})k({h^{-1}}(\tilde{\mathbf{x}}-\mathbf{x}), h^{-1}\cdot \tilde{h})\dfrac{1}{|h|} \,{\rm d}\mathbf{\tilde{x}}\,{\rm d}\tilde{h}. \end{aligned} (f∗groupk)(g)=∫Gf(g~)k(g−1⋅g~)dg~=∫R2∫Hf(x~,h~)LxLhk(x~,h~)∣h∣1dx~dh~=∫R2∫Hf(x~,h~)k(h−1(x~−x),h−1⋅h~)∣h∣1dx~dh~.
提升卷积的主要区别在于信号和核函数 f , k f,k f,k 都是定义在 G G G 上的函数: G → R G\rightarrow \mathbb{R} G→R;
,而积分也反映了这一点,它覆盖了整个群 G G G。除此之外,就没有太大区别了!
在经过一定数量的这样的群卷积层之后,我们最终想要得到一个对群作用不变的表示。我们可以通过执行一个投影来实现这一点,该投影将我们定义在 G G G 上的函数折叠到一个单点,并且该操作对群作用是不变的(如求和、平均、最大值、最小值)。
1.3.安装和导入软件包
在本节中,我们将安装并引入一些在本教程中会用到的库。PyTorch被选为我们的深度学习框架。为了简化模型的训练和跟踪过程,我们还使用了PyTorch Lightning这一库。
## 标准库
import os # 用于操作系统功能,如文件路径
import numpy as np # 用于科学计算
import math # 包含数学函数
from PIL import Image # 用于图像处理
from functools import partial # 用于函数的部分(偏)应用## 绘图相关导入
import matplotlib # 用于绘图
import matplotlib.pyplot as plt # 用于创建图表
%matplotlib inline # 在Jupyter Notebook中使图表在代码单元格内显示## PyTorch
import torch # PyTorch深度学习框架
import torch.nn as nn # 包含神经网络模块
import torch.utils.data as data # 包含数据加载和处理工具
import torch.optim as optim # 包含优化算法## Torchvision
import torchvision # 用于处理图像和视频的PyTorch扩展包
from torchvision.datasets import MNIST # 导入MNIST数据集
from torchvision import transforms # 包含图像变换操作## PyTorch Lightning
try:import pytorch_lightning as pl # 导入PyTorch Lightning,简化训练过程
except ModuleNotFoundError: # 如果模块未找到异常# Google Colab默认没有安装PyTorch Lightning。如果需要,我们在这里安装!pip3 install pytorch-lightning>=1.4 --quiet # 使用pip命令安装PyTorch Lightningimport pytorch_lightning as pl # 再次尝试导入
import pytorch_lightning as pl # 导入PyTorch Lightning库
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint # 导入回调函数,用于监控学习率和模型检查点
# 数据集下载路径,例如MNIST数据集
DATASET_PATH = "../../data"
# 预训练模型保存路径
CHECKPOINT_PATH = "../../saved_models/DL2/GDL"# 确保在GPU上的所有操作是确定性的(如果使用GPU),以保证结果的可复现性
torch.backends.cudnn.deterministic = True # 设置为True以确保每次运行结果相同
torch.backends.cudnn.benchmark = False # 设置为False以关闭 cudnn 库的基准测试模式,这有助于提高确定性# 根据系统是否拥有可用的GPU来分配设备,如果没有GPU,则使用CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
import urllib.request # 导入urllib.request模块,用于请求URLs
from urllib.error import HTTPError # 导入HTTPError,用于捕获HTTP请求错误# GitHub上存储本教程预训练模型的URL
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/DL2/GDL/"
# 需要下载的文件列表
files = ["paprika.tiff"]
# 如果保存模型的路径不存在,则创建它
os.makedirs(CHECKPOINT_PATH, exist_ok=True) # 使用os模块创建目录,如果目录已存在不会抛出异常# 对于列表中的每个文件,检查它是否已经存在。如果不存在,则尝试下载
for file_name in files:file_path = os.path.join(CHECKPOINT_PATH, file_name) # 拼接文件的完整路径if not os.path.isfile(file_path): # 检查文件是否已经存在file_url = base_url + file_name # 构造文件的URLprint(f"Downloading {file_url}...") # 打印下载信息try:urllib.request.urlretrieve(file_url, file_path) # 尝试下载文件except HTTPError as e: # 如果发生HTTP错误print("Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\n", e)# 打印错误信息,并提示用户从其他渠道下载文件或联系作者
2.群论基础
2.1.群的定义
在我们深入群论的高级概念之前,先来回顾一下群的基本概念。群是一个由集合 G G G 和二元运算 ⋅ \cdot ⋅ 组成的数学对象。集合 G G G 包含了所有的群元素,而运算 ⋅ \cdot ⋅ 定义了这些元素如何相互作用。群的运算必须满足以下条件:
- 封闭性:对于任意两个群元素 g 1 , g 2 ∈ G g_1, g_2 \in G g1,g2∈G,它们的运算结果 g 1 ⋅ g 2 g_1 \cdot g_2 g1⋅g2 也必须在 G G G 中。
- 单位元存在性:存在一个特殊元素 e ∈ G e \in G e∈G,使得对于所有 g ∈ G g \in G g∈G,都有 e ⋅ g = g ⋅ e = g e \cdot g = g \cdot e = g e⋅g=g⋅e=g。
- 逆元存在性:对于每个元素 g ∈ G g \in G g∈G,都有一个对应的逆元素 g − 1 ∈ G g^{-1} \in G g−1∈G,使得 g ⋅ g − 1 = g − 1 ⋅ g = e g \cdot g^{-1} = g^{-1} \cdot g = e g⋅g−1=g−1⋅g=e。
- 结合律:对于任意三个群元素 g 1 , g 2 , g 3 ∈ G g_1, g_2, g_3 \in G g1,g2,g3∈G,运算的顺序不影响结果,即 ( g 1 ⋅ g 2 ) ⋅ g 3 = g 1 ⋅ ( g 2 ⋅ g 3 ) (g_1 \cdot g_2) \cdot g_3 = g_1 \cdot (g_2 \cdot g_3) (g1⋅g2)⋅g3=g1⋅(g2⋅g3)。
群还可以对 R 2 \mathbb{R}^2 R2 上定义的函数进行变换,这通过正则表示 L g G → R 2 \mathcal{L}_g^{\mathbb{G}\rightarrow \mathbb{R}^2} LgG→R2 实现,我们通常简写为 L g \mathcal{L}_g Lg。正则表示定义为:
[\mathcal{L}_g f (\mathbf{x}) = f(g^{-1} \cdot \mathbf{x})]
这里, g − 1 ⋅ x g^{-1} \cdot \mathbf{x} g−1⋅x 表示逆元素 g − 1 g^{-1} g−1 对向量 x \mathbf{x} x 的作用。正则群卷积的名称来源于它使用这种正则表示来变换网络中使用的卷积核 k k k。
2.2. Python中群的实现
接下来,我们从创建一个基类开始,这个基类定义了在实现群卷积神经网络时所需的函数和属性。由于我们将使用 PyTorch 框架来实现群卷积神经网络,我们也将把群实现为一个 PyTorch 模块。
我们首先定义一个名为 GroupBase
的基类,它包含了实现群卷积所必需的所有属性和操作。在实现群卷积时,实现这个基类中定义的函数是扩展到其他群的必要和充分条件。简单来说,如果你想要为一个你感兴趣的新群实现群卷积,只需继承这个基类并实现其方法即可。(实际上,这种方法主要适用于离散的、紧凑的群)
class GroupBase(torch.nn.Module):# 继承自PyTorch的nn.Module类,用于实现群的基类def __init__(self, dimension, identity):"""构造函数,用于初始化群的属性。@param dimension: 群的维度(在代数的基中维度的数量)。@param identity: 群的单位元素。"""super().__init__() # 调用基类的构造函数self.dimension = dimension # 群的维度# 注册群的单位元素为一个不会参与梯度计算的缓冲区self.register_buffer('identity', torch.Tensor(identity))def elements(self):"""获取一个张量,其中包含此群中的所有群元素。应实现具体的子类中。"""raise NotImplementedError()def product(self, h, h_prime):"""定义两个群元素的群乘积。@param h: 群元素1@param h_prime: 群元素2应实现具体的子类中。"""raise NotImplementedError()def inverse(self, h):"""定义群元素的逆。@param h: 子群H中的一个群元素应实现具体的子类中。"""raise NotImplementedError()def left_action_on_R2(self, h, x):"""子群H中的一个元素对R2中的向量的群作用。@param h: 子群H中的一个群元素@param x: R2中的向量应实现具体的子类中。"""raise NotImplementedError()def matrix_representation(self, h):"""获取群元素h在R^2中的矩阵表示。@param h: 群元素应实现具体的子类中。"""raise NotImplementedError()def determinant(self, h):"""计算群元素h的表示的行列式。@param h: 群元素应实现具体的子类中。"""raise NotImplementedError()def normalize_group_parameterization(self, h):"""将群元素映射到区间[-1, 1]。我们使用这个函数来创建一个标准化的输入,以便在群上获得权重。@param h: 群元素应实现具体的子类中。"""raise NotImplementedError()
2.3. 实现群旋转 C 4 C_4 C4
为了说明,我们考虑一个相对简单的群——循环群 C 4 C_4 C4,它代表了平面上所有的 90 ° 90° 90° 旋转。以下是关于这个群的一些要点:
- C 4 C_4 C4 的群元素集合为 G : = { e , g , g 2 , g 3 } G := \{ e, g, g^2, g^3\} G:={e,g,g2,g3}。这些群元素可以用旋转角度 θ \theta θ 来表示,即 e e e 对应 0 ° 0° 0°, g g g 对应 90 ° 90° 90°, g 2 g^2 g2 对应 180 ° 180° 180°,以此类推。
- 群的乘法通过角度相加模 2 π 2\pi 2π 来定义,即 g ⋅ g ′ = θ + θ ′ m o d 2 π g \cdot g' = \theta + \theta' \mod 2\pi g⋅g′=θ+θ′mod2π。
- 群的逆元通过角度取负模 2 π 2\pi 2π 来获得,即 g − 1 = − θ m o d 2 π g^{-1} = -\theta \mod 2\pi g−1=−θmod2π。
- 循环群 C 4 C_4 C4 在二维欧几里得平面 R 2 \mathbb{R}^2 R2 上的作用是通过旋转矩阵 R θ R_\theta Rθ 实现的,其中
R θ = [ cos ( θ ) − sin ( θ ) sin ( θ ) cos ( θ ) ] . R_{\theta} = \begin{bmatrix} \cos(\theta) & -\sin(\theta) \\ \sin(\theta) & \cos(\theta) \end{bmatrix}. Rθ=[cos(θ)sin(θ)−sin(θ)cos(θ)].
- 这使得我们可以定义在 R 2 \mathbb{R}^2 R2 上定义的函数 f f f 上的正则表示 L θ \mathcal{L}_\theta Lθ,具体形式为
L θ f ( x ) = f ( R − θ m o d 2 π x ) . \mathcal{L}_{\theta} f(\mathbf{x}) = f(R_{-\theta\mod2\pi}\mathbf{x}). Lθf(x)=f(R−θmod2πx).
这样,我们就能够用数学工具来描述和操作 C 4 C_4 C4 群的元素和它们在二维平面上的作用。
以下代码定义了一个名为 CyclicGroup
的类,它继承自之前提到的 GroupBase
类,并实现了一个旋转群的基本操作。
class CyclicGroup(GroupBase):# 构造函数,初始化循环群的属性def __init__(self, order):super().__init__(dimension=1,identity=[0.])# 确保阶数大于1assert order > 1self.order = torch.tensor(order)# 获取群的所有元素def elements(self):return torch.linspace(start=0,end=2 * np.pi * (self.order - 1) / self.order,steps=self.order,device=self.identity.device)# 定义循环群中两个群元素的乘积def product(self, h, h_prime):# 循环群的群乘积可以通过模运算实现,这里使用向量加法product = (h + h_prime) % (2 * np.pi / self.order)return product# 定义循环群中群元素的逆def inverse(self, h):# 循环群中元素的逆是其补角inverse = -h % (2 * np.pi / self.order)return inverse# 定义群元素对R2空间向量的左作用def left_action_on_R2(self, h, x):transformed_x = torch.tensordot(self.matrix_representation(h), x, dims=1)return transformed_x# 获取群元素在R^2中的矩阵表示def matrix_representation(self, h):# 循环群的矩阵表示是一个旋转矩阵representation = torch.tensor([[np.cos(h), -np.sin(h)],[np.sin(h), np.cos(h)]])return representation# 将群元素的值归一化到-1和1之间def normalize_group_elements(self, h):largest_elem = 2 * np.pi * (self.order - 1) / self.ordernormalized_h = (2 * h / largest_elem) - 1return normalized_h# 测试代码,验证实现的正确性
c4 = CyclicGroup(order=4)
e, g1, g2, g3 = c4.elements() # 获取群的元素,包括单位元e和其它元素# 验证群的乘积和逆元
assert c4.product(e, g1) == g1 and c4.product(g1, g2) == g3
assert c4.product(g1, c4.inverse(g1)) == e# 验证单位元和特定元素的矩阵表示
assert torch.allclose(c4.matrix_representation(e), torch.eye(2))
assert torch.allclose(c4.matrix_representation(g2), torch.tensor([[-1, 0], [0, -1]]).float(), atol=1e-6)# 验证群元素对R2向量的左作用
assert torch.allclose(c4.left_action_on_R2(g1, torch.tensor([0., 1.])), torch.tensor([-1., 0.]), atol=1e-7)
代码实现了一个旋转群,特别地,它实现了4阶旋转群(C4),这是一个具有4个元素的群,通常用于表示90度的旋转。代码中包含了群元素的生成、群乘积、求逆、左作用以及矩阵表示等操作。最后,通过测试代码验证了实现的正确性。
2.4 可视化群操作
现在,我们将通过我们刚刚创建的群实现来实际操作一番!我们会在一些蔬菜图片上展示群的作用效果。为了从变换后的图像网格中获取像素值,我们将使用 PyTorch 的 grid_sample
函数(文档链接)。虽然你可以深入文档以详细了解该函数的工作原理,但简而言之,我们将使用它来进行双线性插值(对于二维数据)或多线性插值(对于更高维度),以便从旋转后的核网格中获取对应的权重值。接下来,我们将观察这些蔬菜图片在循环群 C 4 C_4 C4的作用下如何变换。
下面的代码演示了如何使用PyTorch和自定义的群类 CyclicGroup
来执行图像上的变换,特别是通过一个循环群的元素实现的图像旋转。
# 从磁盘加载图像。
img = Image.open(os.path.join(CHECKPOINT_PATH, "paprika.tiff"))# 将图像转换为torch张量。
img_tensor = transforms.ToTensor()(img)# 定义双线性插值函数,用于通过双线性插值获取一组网格点的信号值。
def bilinear_interpolation(signal, grid):# ...# 定义三线性插值函数,用于通过三线性插值获取一组网格点的信号值。
def trilinear_interpolation(signal, grid):# ...# 创建图像像素位置的网格,图像尺寸为[2, 512, 512]。
img_grid = torch.stack(torch.meshgrid(torch.linspace(-1, 1, img_tensor.shape[-2]),torch.linspace(-1, 1, img_tensor.shape[-1]),indexing='ij'
))# 创建一个90度顺时针旋转的循环群。
c4 = CyclicGroup(order=4)
e, g1, g2, _ = c4.elements()# 使用e, g1和g2创建一个270度逆时针旋转的元素g3。
## 你的代码从这里开始 ##
g3 = (g2 + g1) % (2 * np.pi / c4.order) # 270度旋转对应3个90度旋转的元素
## 你的代码在这里结束 ##assert g3 == c4.elements()[-1] # 确保g3是群的最后一个元素,即270度旋转# 使用群元素g3的逆矩阵对创建的图像网格进行变换。
transformed_img_grid = c4.left_action_on_R2(c4.inverse(g3), img_grid)# 在变换后的网格点上对图像进行采样。
transformed_img = bilinear_interpolation(img_tensor, transformed_img_grid)[0]# 将变换后的图像张量转换回PIL图像,以查看变换结果。
transforms.ToPILImage()(transformed_img)
代码中定义了两个插值函数 bilinear_interpolation
和 trilinear_interpolation
,分别用于二维和三维数据的插值操作。然后,代码创建了一个图像像素位置的网格 img_grid
,并初始化了一个4阶循环群 c4
,该群包含4个元素,分别对应0度、90度、180度和270度的旋转。
接下来,代码要求你计算群中对应270度逆时针旋转的元素 g3
。这可以通过组合群中的其他元素来实现,例如,两个90度旋转的元素相加,再对群的阶数取模,即可得到270度旋转的元素。
最后,代码使用群元素 g3
的逆矩阵对图像网格进行变换,并使用双线性插值在变换后的网格点上采样图像,得到变换后的图像 transformed_img
。最终,将这个张量转换为PIL图像对象,以便于可视化。
3. 群等变卷积神经网络
在构建群等变卷积神经网络时,我们通常需要考虑三个关键组件:提升卷积、群卷积以及投影操作。现在,我们将依次介绍这些组件。
3.1 提升卷积
提升卷积是第一步,它的作用是在输入特征图 f i n f_{in} fin 中任何空间位置,根据变换群 H H H 的作用,来“提升”或分离特征。这可以理解为,对于给定的特征 e e e,我们在所有位置注册该特征经过变换 h ∈ H h \in H h∈H 后的版本 L h ( e ) \mathcal{L}_h(e) Lh(e)(或有时用 h ⋅ e h \cdot e h⋅e 表示 h h h 对 e e e 的作用)。通过这个过程,提升卷积将特征从原始的 R 2 \mathbb{R}^2 R2 空间映射到一个扩展的空间 G = R 2 ⋊ H G = \mathbb{R}^2 \rtimes H G=R2⋊H,其中包含了额外的群维度(这些维度的数量取决于变换群 H H H 的维度)。因此,经过提升卷积后得到的特征映射 f o u t f_{out} fout 除了包含传统的空间维度外,还包含了一个或多个群维度。
以90度旋转的变换群 H H H为例,假设有一个特定的图案 e e e,它在我们的输入特征图 f i n f_{in} fin中出现了三次。其中两次是在90度旋转的状态,分别表示为 θ 90 ⋅ e \theta_{90}\cdot e θ90⋅e,而另一次则处于其原始(或标准)的方向,即 θ 0 ⋅ e \theta_0 \cdot e θ0⋅e。
当我们使用一个与图案 e e e完全匹配的卷积核 k k k进行卷积时,它将在特征图的群维度上产生不同的响应。具体来说,一个响应将出现在对应于群元素 θ 0 \theta_0 θ0的空间特征图位置上,而另外两个响应则会出现在对应于 θ 90 \theta_{90} θ90的空间特征图位置上。为了直观地理解这一点,可以参考上面提供的图形。
3.1.1. 概览
我们如何使卷积操作能够捕捉到在不同 H H H变换下的特征呢?从直觉上讲,这与我们在卷积神经网络(CNN)中熟悉的卷积操作并没有太大区别。在CNN中,我们通过在所有空间位置上共享相同的卷积核来提取任意位置的特征。类似地,这里的关键在于设计一个能够处理不同变换的卷积操作。
从群论的角度出发,我们可以将卷积操作解释为对卷积核 k k k应用所有可能的二维平移 x ∈ R 2 \boldsymbol{x} \in \mathbb{R}^2 x∈R2,然后计算输入特征图 f i n f_{in} fin与这些平移后的卷积核 L x ( k ) \mathcal{L}_{\mathbf{x}}(k) Lx(k)之间的内积,以获取响应。这个过程起始于一个在 R 2 \mathbb{R}^2 R2空间上定义的特征图,并最终产生另一个同样在 R 2 \mathbb{R}^2 R2空间上定义的特征图。
我们现在希望在不同的群动作 L h \mathcal{L}_h Lh 下为 h ∈ H h \in H h∈H 注册特征,我们可以通过将核 k k k 通过所有这些群动作进行变换并记录结果来实现。例如,在四阶循环群 C 4 {\rm C_4} C4 的情况下,我们不仅进行平移,还要将核 k k k 按照所有可能的90度旋转进行旋转,并记录这些变换核的结果响应!
3.1.2.实现提升卷积核
让我们开始编程。首先,我们需要定义一个可以在 L h \mathcal{L}_h Lh 下变换的核 k k k。在使用图像时,卷积核通常被定义为一组在 R 2 \mathbb{R}^2 R2 上等距离散化的独立采样权重 W W W(像素均匀分布)。
回顾一下,我们可以通过正则表示 L h \mathcal{L}_h Lh 来表达群 H H H 对定义在 R 2 \mathbb{R}^2 R2 上的函数(如核 k k k)的作用。正则表示通过变换函数 k k k 的定义域来变换函数 k k k。换句话说,正则表示变换了核 k k k 定义的网格,以获得变换后的函数 L h ( k ) \mathcal{L}_h(k) Lh(k) 的值。
因此,为了定义一个可以用群 H H H 的正则表示进行变换的核 k k k,我们需要构建一个定义核值的网格。然后,我们可以通过每个群元素 h ∈ H h \in H h∈H 的作用变换这个网格,得到对应于 H H H 的每个群元素的一组变换核的网格。让我们开始工作!
注意:* 在实现实际的提升和群卷积操作时,我们将利用 PyTorch 的 Conv2D 类。这大大简化了我们的工作,因为 Conv2D 会处理将核 k k k 移动到所有输入位置的工作。因此,我们不需要自己实现平移群 L x \mathcal{L}_{\mathbf{x}} Lx 的作用,但仍然保持平移等变性!要使操作与 Conv2D 兼容,需要一些技巧,但稍后会详细介绍。
下面的代码定义了一个名为 LiftingKernelBase
的类,它是一个用于提升卷积核的基类。这个类存储了定义在 R 2 \mathbb{R}^2 R2 上的提升核的网格以及在群 H H H的作用下的变换副本。
class LiftingKernelBase(torch.nn.Module):def __init__(self, group, kernel_size, in_channels, out_channels):"""构造函数,实现了提升核的基类。存储了定义在 R^2 上的提升核的网格以及在群 H 的作用下的变换副本。"""super().__init__()self.group = groupself.kernel_size = kernel_sizeself.in_channels = in_channelsself.out_channels = out_channels# 创建空间核网格。这些是我们的核权重定义的坐标。self.register_buffer("grid_R2", torch.stack(torch.meshgrid(torch.linspace(-1., 1., self.kernel_size),torch.linspace(-1., 1., self.kernel_size),indexing='ij')).to(self.group.identity.device))# 通过群的元素变换网格。self.register_buffer("transformed_grid_R2", self.create_transformed_grid_R2())def create_transformed_grid_R2(self):"""通过每个群元素的群作用变换创建网格。这将产生一个网格(在 H 上)的空间网格(在 R2 上)。换句话说,是一个网格列表,其中每个索引是原始空间网格通过 H 中相应的群元素变换得到的。"""# 获取所有群元素。## 你的代码从这里开始 ##group_elements = ...## 你的代码在这里结束 ### 使用采样的群元素变换定义在 R2 上的网格。# 回想一下左正则表示是如何作用在 R2 上的函数的域上的!# (提示:仔细看看 1.3 节下的方程)# 我们最终想要得到一个形状为 [2, |H|, kernel_size, kernel_size] 的网格。## 你的代码从这里开始 ##transformed_grid = ...## 你的代码在这里结束 ##return transformed_griddef sample(self, sampled_group_elements):"""为给定数量的群元素采样卷积核参数应包括::param sampled_group_elements: 要采样卷积核的群元素应返回::return kernels: 滤波器库,覆盖所有输入通道,包含为所有输出群元素变换的核。"""raise NotImplementedError()# 让我们检查我们的实现是否正确。首先我们检查变换网格的形状是否正确。
order = 4
lifting_kernel_base = LiftingKernelBase(group=CyclicGroup(order=order),kernel_size=7,in_channels=3,out_channels=1
)# 网格的形状应该是 [2, |H|, kernel_size, kernel_size]。
assert lifting_kernel_base.transformed_grid_R2.shape == torch.Size([2, 4, 7, 7])
plt.rcParams['figure.figsize'] = [12, 3]# 创建 [group_elements] 图像
fig, ax = plt.subplots(1, order)# 获取网格
transformed_grid_R2 = lifting_kernel_base.transformed_grid_R2# 可视化变换的核网格。我们在所有网格中用蓝色的 'x' 标记同一个角点作为参考点。
for group_elem in range(order):ax[group_elem].scatter(transformed_grid_R2[1, group_elem, :, :],transformed_grid_R2[0, group_elem, :, :],c='r')# 标记一个角点,这样我们就可以看到它的变化。ax[group_elem].scatter(transformed_grid_R2[1, group_elem, 0, 0],transformed_grid_R2[0, group_elem, 0, 0],marker='x',c='b')fig.text(0.5, 0., 'Group elements', ha='center')
plt.show()
代码中的 LiftingKernelBase
类初始化时,创建了一个基本的二维网格,这个网格代表了卷积核的权重坐标。然后,它通过 create_transformed_grid_R2
方法变换这个网格,以适应群 (H) 中每个元素的作用。sample
方法应该用于采样给定群元素的卷积核,但这个方法还没有实现(NotImplementedError
)。
测试部分创建了一个4阶循环群的 LiftingKernelBase
实例,并检查变换网格的形状是否正确。然后,使用 matplotlib
库可视化了变换的核网格,以确保它们正确地表示了群元素的作用。
如果你的代码实现无误,你应该能够观察到逆时针旋转的操作正在发生!
目前,我们已经有了一组在群 H H H作用下变换的网格。现在,我们需要决定如何在这些网格的每个网格点上采样卷积核的值。这是应用群卷积神经网络(GCNNs)时面临的首要挑战。
传统卷积神经网络(CNNs)之所以能在所有空间位置上共享相同的权重集,是因为我们仅通过匹配整像素距离的步长来平移卷积核。然而,对于任意群 H H H,我们可能需要获取在标准变换下位于卷积核像素网格之外的网格点的卷积核值。
虽然我们可以(并且实际上会)使用插值来获取像素位置之间的网格点的卷积核值,但这可能会限制模型的表达能力,并引入插值伪影。
注意事项:
当我们在实现90度旋转群 H = C 4 H=C_4 H=C4的等变性时,由于所有变换后的网格都共享相同的位置,我们当然可以避免使用插值。我们可以通过权重的重新排列来实现这个特定群的群操作。但在这个教程中,为了更具通用性,我们将采用插值的方法!幸运的是,PyTorch提供了一个函数,允许我们在网格上采样输入;我们将使用PyTorch的grid_sample
函数来进行插值!
以下的代码定义了一个名为 InterpolativeLiftingKernel
的类,继承自 LiftingKernelBase
类,用于创建一个插值提升卷积核。这个类初始化了一组权重,并通过插值生成变换后的空间卷积核。
class InterpolativeLiftingKernel(LiftingKernelBase):def __init__(self, group, kernel_size, in_channels, out_channels):super().__init__(group, kernel_size, in_channels, out_channels)# 创建并初始化一组权重,我们将通过插值来创建变换后的空间核。self.weight = torch.nn.Parameter(torch.zeros((self.out_channels,self.in_channels,self.kernel_size,self.kernel_size), device=self.group.identity.device))# 使用 kaiming 均匀初始化权重。torch.nn.init.kaiming_uniform_(self.weight.data, a=math.sqrt(5))def sample(self):"""为给定数量的群元素采样卷积核应返回::return kernels: 扩展到所有输入通道的滤波器库,包含为所有输出群元素变换的核。"""# 首先,我们将输出通道维度折叠到输入通道维度中;# 这允许我们使用 torch grid_sample 函数一次性变换整个滤波器库。## 你的代码从这里开始 ##weight = self.weight## 你的代码在这里结束 ### 采样变换后的核。transformed_weight = []for spatial_grid_idx in range(self.group.elements().numel()):transformed_weight.append(bilinear_interpolation(weight, self.transformed_grid_R2[:, spatial_grid_idx, :, :]))transformed_weight = torch.stack(transformed_weight)# 分离输入和输出通道。transformed_weight = transformed_weight.view(self.group.elements().numel(),self.out_channels,self.in_channels,self.kernel_size,self.kernel_size)# 将输出通道维度放在群维度之前。我们这样做# 是为了能够使用 PyTorch 的 Conv2D。细节见下文!transformed_weight = transformed_weight.transpose(0, 1)return transformed_weight# 实例化 InterpolativeLiftingKernel
ik = InterpolativeLiftingKernel(group=CyclicGroup(order=4),kernel_size=7,in_channels=2,out_channels=1
)# 采样权重
weights = ik.sample()
# 查看权重的形状
print(weights.shape)# 选择一个输出通道进行可视化
out_channel_idx = 0# 创建 [in_channels, group_elements] 图像
fig, ax = plt.subplots(ik.in_channels, ik.group.elements().numel())# 可视化每个输入通道和群元素下的权重
for in_channel in range(ik.in_channels):for group_elem in range(ik.group.elements().numel()):ax[in_channel, group_elem].imshow(weights[out_channel_idx, group_elem, in_channel, :, :].detach().numpy())# 添加标题和图例
fig.text(0.5, 0.04, 'Group elements', ha='center')
fig.text(0.04, 0.5, 'Input channels', va='center', rotation='vertical')plt.show()
在 InterpolativeLiftingKernel
类中,__init__
方法创建了一个可学习的权重参数 self.weight
,并使用 Kaiming 均匀初始化方法对其进行初始化。sample
方法通过双线性插值变换核权重,生成对应于群元素的变换核。
测试部分创建了一个 InterpolativeLiftingKernel
的实例,并调用 sample
方法来采样权重。然后,使用 matplotlib
库可视化了不同输入通道和群元素下的权重。
我们观察时,可以看到空间卷积核在旋转群元素的作用下发生了旋转!
2.1.3 实现提升卷积操作
现在,我们终于可以着手实现提升卷积操作了!这个类的主要功能是接收一个在 R 2 \mathbb{R}^2 R2空间上定义的特征图,并输出一个在 R 2 ⋊ H \mathbb{R}^2\rtimes H R2⋊H空间上的特征图。在这个输出中,不同变换 h ∈ H h \in H h∈H下的特征会沿着 H H H轴被区分开来。
注意事项:
为了避免重复造轮子并实现自己的卷积操作,我们决定利用PyTorch中高度优化的Conv2D类。为了达到这个目的,我们采用了一些巧妙的策略。通常,卷积层会在输入特征图上应用一组包含 n n n个空间卷积核的集合,其中 n n n是卷积操作的输出通道数。然而,现在我们还需要应用这些卷积核的num_group_elem个变换版本。为了简化操作,我们巧妙地让PyTorch将这些变换视为不同的输出通道。具体实现上,我们将原始的[out_channels, num_group_elem, in_channels, kernel_size, kernel_size]卷积核集合重塑为[out_channels * num_group_elem, in_channels, kernel_size, kernel_size]的集合。这样,PyTorch就能自动处理所有的变换版本了!
此外,使用PyTorch的Conv2D类还有一个额外的优点,那就是我们无需自己计算平移后的卷积核 L x ( k ) \mathcal{L}_{\mathbf{x}}(k) Lx(k),因为PyTorch已经为我们处理了这些细节。
以下代码定义了一个名为 LiftingConvolution
的类,它是一个执行提升卷积的 PyTorch 模块。这个类使用了之前定义的 InterpolativeLiftingKernel
类来生成变换后的卷积核,并通过 forward
方法应用提升卷积。
class LiftingConvolution(torch.nn.Module):def __init__(self, group, in_channels, out_channels, kernel_size, padding):super().__init__()# 初始化提升核,用于创建插值变换后的空间卷积核。self.kernel = InterpolativeLiftingKernel(group=group,kernel_size=kernel_size,in_channels=in_channels,out_channels=out_channels)# 填充参数self.padding = paddingdef forward(self, x):"""执行提升卷积@param x: 输入样本 [batch_dim, in_channels, spatial_dim_1, spatial_dim_2]@return: 群的同质空间上的函数[batch_dim, out_channels, num_group_elements, spatial_dim_1, spatial_dim_2]"""# 获取在群作用下变换的卷积核。## 你的代码从这里开始 ##conv_kernels = self.kernel.sample()## 你的代码在这里结束 ### 应用提升卷积。注意,使用 reshape 我们可以将核的群维度折叠到输出通道维度。# 我们将每个变换后的核视为一个额外的输出通道。这样我们可以利用 PyTorch 的 conv2d 函数!# 问题:你明白我们(可以)为什么这么做吗?## 你的代码从这里开始 ### 将 conv_kernels 重塑为适合 conv2d 的形状,并执行卷积x = torch.nn.functional.conv2d(x, conv_kernels.view(-1, conv_kernels.size(-3), conv_kernels.size(-2), conv_kernels.size(-1)),padding=self.padding)## 你的代码在这里结束 ### 重新塑形 [batch_dim, in_channels * num_group_elements, spatial_dim_1, spatial_dim_2]# 变为 [batch_dim, in_channels, num_group_elements, spatial_dim_1, spatial_dim_2],# 将通道和群维度分开。x = x.view(-1,self.kernel.out_channels,self.kernel.group.elements().numel(),x.shape[-1],x.shape[-2])return x# 实例化 LiftingConvolution
lifting_conv = LiftingConvolution(group=CyclicGroup(order=4),kernel_size=5,in_channels=3,out_channels=8,padding=False
)
在 LiftingConvolution
类中,__init__
方法初始化了 self.kernel
,它是一个 InterpolativeLiftingKernel
实例,用于生成变换后的卷积核。forward
方法执行提升卷积,首先通过调用 self.kernel.sample()
获取变换后的卷积核,然后使用 PyTorch 的 conv2d
函数应用这些核到输入 x
上。
注意,在 forward
方法中,我们使用 reshape
方法将变换后的卷积核的形状调整为 conv2d
函数所需的形状。然后,我们使用 view
方法重新排列输出的维度,以将群元素的维度与输出通道的维度分开。
最后,代码实例化了一个 LiftingConvolution
对象,指定了群的阶数为4,卷积核的大小为5,输入通道数为3,输出通道数为8,并设置填充为 False
。
当2.2 群卷积的实现
当我们谈及在群 G = R 2 ⋊ H G = \mathbb{R}^2 \rtimes H G=R2⋊H上执行卷积操作时,我们必须考虑到输入特征图 f i n f_{in} fin不仅包含定义在 R 2 \mathbb{R}^2 R2上的空间维度,还额外包含定义在群 H H H上的群维度。因此,为了设计一个能够处理这种复合结构的卷积层,我们需要定义一个群卷积核 k g r o u p k_{group} kgroup,它同样在群 G G G上定义。
3.2.群卷积
3.2.1. 群卷积的核心概念
与普通的空间卷积不同,群卷积涉及到与定义在群 G G G上的卷积核 k g r o u p k_{group} kgroup的交互。这个卷积核可以看作是一系列空间卷积核的集合,每个群元素 h ∈ H h \in H h∈H对应一个空间卷积核。重要的是,由于群 H H H的作用, k g r o u p k_{group} kgroup不仅会在空间上发生变化(如旋转),还会沿着群轴进行平移。
以 H = C 4 H = C_4 H=C4(即四元循环群)为例,当我们应用群元素 θ ∈ C 4 \theta \in C_4 θ∈C4时,这不仅会导致卷积核在空间域上的旋转,还会引起其在群轴上的平移。这种“扭曲-平移”的效应是群卷积的一个关键特性。
除了上述差异外,群卷积操作符的工作原理与提升操作符类似。我们再次利用群 H H H和 R 2 \mathbb{R}^2 R2的作用来变换群卷积核 k g r o u p k_{group} kgroup,并计算该核与输入之间的内积响应。虽然在实际操作中,平移群相关的计算工作由PyTorch自动处理,但整体流程保持了与提升操作的相似性。下面是对这一过程的直观描述。
3.2.2. 群卷积的步骤
在实现群卷积时,我们需要确保卷积操作能够同时处理空间维度和群维度。以下是一个简化的步骤概述:
-
定义群卷积核:首先,我们需要为群 G G G定义一个卷积核 k g r o u p k_{group} kgroup。这通常涉及为群 H H H的每个元素 h h h创建一个相应的空间卷积核。
-
重塑输入特征图:由于输入特征图包含群维度,我们需要确保它能够与群卷积核正确匹配。这可能需要我们对输入特征图进行重塑或扩展,以便它能够与群卷积核的维度相匹配。
-
执行群卷积:一旦我们有了群卷积核和适当重塑的输入特征图,我们就可以执行群卷积操作了。这通常涉及对输入特征图上的每个位置进行遍历,并使用相应的群卷积核进行卷积计算。
-
处理群轴上的平移:由于群 H H H的作用,我们还需要确保在群轴上正确处理平移。这可能需要我们在执行卷积操作时考虑到群元素的作用,并相应地调整卷积核或特征图的位置。
-
输出结果:最后,我们得到的是经过群卷积处理的输出特征图,它同样包含空间维度和群维度。这个输出特征图可以被用作后续层的输入,以进一步处理和分析数据。
3.2.3. 定义群卷积核
同样地,我们定义一个卷积核 k k k,该卷积核可以通过群作用进行变换。现在,我们的卷积核网格不仅定义在 R 2 \mathbb{R}^2 R2上,还额外定义在群 H H H上。
注意:
-
由于群 H H H上的网格由元素 h ′ ∈ H h' \in H h′∈H组成,使用(另一个)群元素 h ∈ H h \in H h∈H对 H H H上的网格进行变换,实际上就是应用 h h h与每个网格元素 h ′ h' h′的群乘积。
-
因为我们处理的是半直积群 R 2 ⋊ H \mathbb{R}^2 \rtimes H R2⋊H,我们可以在将网格组合成共享在 R 2 ⋊ H \mathbb{R}^2 \rtimes H R2⋊H上的网格之前,分别变换网格的 R 2 \mathbb{R}^2 R2和 H H H维度。
下面的代码定义了一个名为 GroupKernelBase
的类,它是一个用于群卷积核的基类。这个类存储了在群 R 2 ⋊ H R^2 \rtimes H R2⋊H上定义的网格以及在群 H H H 的所有元素作用下的变换副本。
class GroupKernelBase(torch.nn.Module):def __init__(self, group, kernel_size, in_channels, out_channels):"""构造函数,实现了群卷积核的基类。存储了在群 \( R^2 \rtimes H \) 上定义的网格以及在群 \( H \) 的所有元素作用下的变换副本。"""super().__init__()self.group = groupself.kernel_size = kernel_sizeself.in_channels = in_channelsself.out_channels = out_channels# 创建空间核网格self.register_buffer("grid_R2", torch.stack(torch.meshgrid(torch.linspace(-1., 1., self.kernel_size),torch.linspace(-1., 1., self.kernel_size),indexing='ij')).to(self.group.identity.device))# 核网格现在还扩展到群 H 上,因为我们的输入特征映射包含一个额外的群维度self.register_buffer("grid_H", self.group.elements())self.register_buffer("transformed_grid_R2xH", self.create_transformed_grid_R2xH())def create_transformed_grid_R2xH(self):"""通过 H 中每个群元素的群作用变换在 \( R^2 \rtimes H \) 上创建的网格。这产生了一组群上的网格。换句话说,是一个网格列表,其中每个索引是通过 H 中相应的群元素变换得到的 \( G \) 上的原始网格。"""# 采样群 H。## 你的代码从这里开始 ##group_elements = self.group.elements()## 你的代码在这里结束 ### 使用采样的群元素变换定义在 R2 上的网格。# 我们希望最终得到一个形状为 [2, |H|, kernel_size, kernel_size] 的网格。## 你的代码从这里开始 ##transformed_grid_R2 = self.group.left_action_on_R2(group_elements, self.grid_R2)## 你的代码在这里结束 ### 使用采样的群元素变换定义在 H 上的网格。我们想要一个形状为 [|H|, |H|] 的网格。# 确保像上面一样(在第一维上)堆叠变换。## 你的代码从这里开始 ##transformed_grid_H = group_elements.unsqueeze(0) # 假设 group_elements 是 [|H|] 形状## 你的代码在这里结束 ### 重新调整值到 -1 和 1 之间,我们这样做是为了满足 torch# grid_sample 函数的要求。transformed_grid_H = self.group.normalize_group_elements(transformed_grid_H)# 创建一个组合网格,作为 R2 和 H 上网格的乘积# 在群维度上重复 R2,并在空间维度上重复 H# 以创建一个形状为 [3, |H|, |H|, kernel_size, kernel_size] 的网格transformed_grid = torch.cat((transformed_grid_R2.view(2,group_elements.numel(),1,self.kernel_size,self.kernel_size,).repeat(1, 1, group_elements.numel(), 1, 1),transformed_grid_H.view(1,group_elements.numel(),group_elements.numel(),1,1,).repeat(1, 1, 1, self.kernel_size, self.kernel_size)),dim=0)return transformed_griddef sample(self, sampled_group_elements):"""为给定数量的群元素采样卷积核参数应包括::param sampled_group_elements: 要采样卷积核的群元素应返回::return kernels: 滤波器库,覆盖所有输入通道,包含为所有输出群元素变换的核。"""raise NotImplementedError()
在 GroupKernelBase
类中,__init__
方法创建了空间核网格,并存储了群 H H H的元素以及这些元素作用下网格的变换副本。create_transformed_grid_R2xH
方法变换了 R 2 ⋊ H R^2 \rtimes H R2⋊H 上的网格,产生了群 H H H上的一组网格。
sample
方法应该用于采样给定群元素的卷积核。
当我们对网格应用群 H H H的群作用时,让我们先直观地理解一下网格上发生了什么。首先,我们将观察群作用在 R 2 \mathbb{R}^2 R2上的效果。
在当前的设置中,群 H H H是一维的,正如我们所见,使用 H H H的所有群元素对 H H H上的网格进行变换,会导致网格在群上的平移。接下来,我们来看看当我们将这些网格组合起来时会发生什么。
下面的代码演示了如何使用 GroupKernelBase
类创建一个群卷积核的基础实例,并对其进行可视化。
# 创建 GroupKernelBase 的实例
k_base = GroupKernelBase(group=CyclicGroup(order=4),kernel_size=7,in_channels=1,out_channels=1
)# 获取变换后的网格形状
print(k_base.transformed_grid_R2xH.shape)# 设置绘图参数
plt.rcParams['figure.figsize'] = [10, 3]# 创建群元素数量的图形。
fig, ax = plt.subplots(1,k_base.group.elements().numel(),subplot_kw=dict(projection='3d')
)# 将空间和群网格维度展平。
transformed_grid_R2xH = k_base.transformed_grid_R2xH.reshape(3,k_base.group.elements().numel(),k_base.group.elements().numel() * k_base.kernel_size * k_base.kernel_size
)# 可视化变换的核网格。我们在所有网格中用蓝色的 'x' 标记同一行为参考点。
for group_elem in range(k_base.group.elements().numel()):ax[group_elem].scatter(transformed_grid_R2xH[1, group_elem, 1:], # X 坐标transformed_grid_R2xH[0, group_elem, 1:], # Y 坐标transformed_grid_R2xH[2, group_elem, 1:], # Z 坐标c='r')# 标记一个角点,这样我们就可以看到它的变换。ax[group_elem].scatter(transformed_grid_R2xH[1, group_elem, 0], # X 坐标transformed_grid_R2xH[0, group_elem, 0], # Y 坐标transformed_grid_R2xH[2, group_elem, 0], # Z 坐标marker='x',c='b'
)# 添加标题
fig.text(0.5, 0.04, 'Group elements', ha='center')# 显示图形
plt.show()
代码中首先创建了一个 GroupKernelBase
类的实例 k_base
,该实例使用了一个阶数为4的循环群 CyclicGroup
,卷积核大小为7,输入和输出通道数均为1。
接着,代码打印了变换后的网格 transformed_grid_R2xH
的形状,这有助于我们理解数据的组织方式。
然后,代码设置了绘图参数,并为每个群元素创建了一个子图,准备进行3D可视化。
在可视化循环中,代码使用 scatter
方法在3D空间中绘制了变换后的核网格的点,并用蓝色 ‘x’ 标记了每个网格中的一个参考点。
最后,使用 plt.show()
显示了整个图形。
我们可以看到,在应用群元素 h ′ ∈ H h' \in H h′∈H时,定义在 R 2 ⋊ H \mathbb{R}^2 \rtimes H R2⋊H上的网格会在空间维度上旋转,并在群维度上平移!
现在,让我们也使用插值来实现群卷积核。
注意:
- 对于多维群 H H H,以下实现将不起作用,因为这需要定义在维度大于3的网格上的卷积核,而我们的grid_sample的三线性插值实现不支持这一点。为了解决这个问题,可以沿着 H H H维度对权重矩阵进行平移来采样权重,而仅在空间维度上进行插值。这是可能的,因为我们在群维度 H H H上不会落在网格点之间(还记得群积的封闭性约束吗?)。
- C 4 C_4 C4群在群轴上具有周期性,因此我们的卷积核也应该具有周期性。尽管我们正确地实现了群积以反映这一点,但grid_sample在插值过程中并不知道权重的周期性。然而,由于群积的封闭性约束,我们应该始终在群轴上精确地落在网格点上,这意味着在该方向上不需要插值。在实践中,由于grid_sample的实现方式,我们可能会因为这一点而遇到一些小的插值伪影。
下面的代码定义了一个名为 InterpolativeGroupKernel
的类,它继承自 GroupKernelBase
类,并用于创建一组可以通过插值变换的卷积核。
class InterpolativeGroupKernel(GroupKernelBase):def __init__(self, group, kernel_size, in_channels, out_channels):super().__init__(group, kernel_size, in_channels, out_channels)# 创建并初始化一组权重,我们将通过插值来创建变换后的空间卷积核。# 注意,现在权重也扩展到群 H 上。## 你的代码从这里开始 ##self.weight = torch.nn.Parameter(torch.zeros((self.out_channels, self.in_channels, self.kernel_size, self.kernel_size),device=self.group.identity.device))## 你的代码在这里结束 ### 使用 kaiming 均匀初始化权重。torch.nn.init.kaiming_uniform_(self.weight.data, a=math.sqrt(5))def sample(self):"""为给定数量的群元素采样卷积核应返回::return kernels: 扩展到所有输入通道的滤波器库,包含为所有输出群元素变换的核。"""# 首先,我们将输出通道维度折叠到输入通道维度中;# 这允许我们使用插值函数一次性变换整个滤波器库。## 你的代码从这里开始 ##weight = self.weight.view(self.out_channels * self.in_channels,self.group.elements().numel(),self.kernel_size,self.kernel_size)## 你的代码在这里结束 ##transformed_weight = []# 我们遍历所有群元素,并检索对应于 R2xH 上变换网格的权重值。for grid_idx in range(self.group.elements().numel()):transformed_weight.append(trilinear_interpolation(weight, self.transformed_grid_R2xH[:, grid_idx, :, :, :]))transformed_weight = torch.stack(transformed_weight)# 分离输入和输出通道。transformed_weight = transformed_weight.view(self.group.elements().numel(),self.out_channels,self.in_channels,self.group.elements().numel(),self.kernel_size,self.kernel_size)# 将输出通道维度放在群维度之前。我们这样做# 是为了能够使用 PyTorch 的 Conv2D。细节见下文!transformed_weight = transformed_weight.transpose(0, 1)return transformed_weight# 实例化 InterpolativeGroupKernel
igk = InterpolativeGroupKernel(group=CyclicGroup(order=4),kernel_size=5,in_channels=2,out_channels=8
)# 采样权重
weights = igk.sample()
# 打印权重的形状
print(weights.shape)
在 InterpolativeGroupKernel
类中,__init__
方法创建了一个可学习的权重参数 self.weight
,并使用 Kaiming 均匀初始化方法对其进行初始化。sample
方法通过三线性插值变换核权重,生成对应于群元素的变换核。
在 sample
方法中,首先将权重重塑为适合插值函数的形状。然后,对于每个群元素,使用 trilinear_interpolation
函数插值变换后的权重。最后,将变换后的权重分离为输入和输出通道,并重新排列维度,以便使用 PyTorch 的 Conv2D
函数。
测试部分创建了一个 InterpolativeGroupKernel
的实例,并调用 sample
方法来采样权重。然后,打印出采样得到的权重的形状。
让我们可视化采样的群卷积核!通过将输入群维度折叠到第一个空间维度中,我们以二维方式展示三维卷积核。通过这种方式,我们创建了三维群卷积核的二维展平版本,其中对应于不同群元素的空间核沿着空间维度排列。每个通道从[num_group_elem, kernel_size, kernel_size]转变为[num_group_elem x kernel_size, kernel_size]。
为了清晰地看到群卷积核在群 H H H的变换下如何变化,我们用红色勾勒出与第一个输入群元素对应的空间核。在后续的变换中,我们可以看到这个空间核。请看下面的内容!
下面的代码是一个可视化部分,用于展示 InterpolativeGroupKernel
类生成的权重如何根据不同的群元素变换。
# 设置绘图参数
plt.rcParams['figure.figsize'] = [10, 10]# 为了方便查看,我们将输入群维度折叠到空间 x 维度
weights_t = weights.view(igk.out_channels,igk.group.elements().numel(),igk.in_channels,igk.group.elements().numel() * igk.kernel_size,igk.kernel_size
)# 选择一个输出通道进行可视化
out_channel_idx = 0# 创建 [in_channels, group_elements] 图像
fig, ax = plt.subplots(igk.in_channels, igk.group.elements().numel())# 对于每个输入通道和群元素,展示权重
for in_channel in range(igk.in_channels):for group_elem in range(igk.group.elements().numel()):ax[in_channel, group_elem].imshow(weights_t[out_channel_idx, group_elem, in_channel, :, :].detach())# 用红色边框标出标准变换下第一个群元素对应的空间核rect = matplotlib.patches.Rectangle((-0.5, group_elem * weights_t.shape[-1] - 0.5), weights_t.shape[-1], weights_t.shape[-1], linewidth=5, edgecolor='r', facecolor='none')ax[in_channel, group_elem].add_patch(rect)# 添加标题
fig.text(0.5, 0.04, 'Group elements', ha='center')
fig.text(0.04, 0.5, 'Input channels / input group elements', va='center', rotation='vertical')# 显示图形
plt.show()
代码中首先调整了绘图的尺寸,然后重新塑形权重张量 weights_t
,以便将输入群维度的信息合并到空间的 x 维度上。这使得我们可以更方便地查看每个群元素对应的权重变换。
接着,代码选择了一个输出通道(out_channel_idx = 0
)的权重进行可视化。它创建了一个图像网格,每个子图对应一个输入通道和群元素的组合。
在双重循环中,对于每个输入通道和群元素,代码使用 imshow
函数展示了权重的二维分布图。同时,使用 matplotlib.patches.Rectangle
创建了一个红色的矩形框,来标出标准变换下第一个群元素对应的空间核的位置。
最后,代码使用 plt.show()
显示了整个图形,使得我们可以直观地看到不同群元素作用下权重的变换情况。
我们看到与核网格相同的扭曲平移运动!
3.2.4. 定义群卷积
下一步是实现群卷积操作。
注意:
- 我们仍然希望使用PyTorch的Conv2D实现,但现在我们面临一个额外的问题;输入特征图中的群维度。幸运的是,我们可以用类似的方式解决这个问题;通常,二维卷积层会在所有输入通道的局部邻域上进行积分。我们现在还希望在整个群上进行积分。因此,我们可以简单地将输入特征图中的群维度视为额外的通道维度!我们通过将输入群维度折叠到输入通道维度中来实现这一点;(f_{in})的形状从[batch, in_channels, num_group_elem, spatial_1, spatial_2]重塑为[batch, in_channels x num_group_elem, spatial_1, spatial_2]。
- 为了匹配这一点,并应用与提升卷积中相同的技巧以获取每个单独群元素在输出中的结果,我们也将卷积核从[out_channels, num_group_elem, in_channels, num_group_elem, kernel_size, kernel_size]重塑为[out_channels x num_group_elem, in_channels x num_group_elem, kernel_size, kernel_size]。请参见下面的内容!
下面的代码定义了一个名为 GroupConvolution
的类,它是一个执行群卷积操作的 PyTorch 模块。这个类使用了之前定义的 InterpolativeGroupKernel
类来生成变换后的卷积核,并通过 forward
方法应用群卷积。
class GroupConvolution(torch.nn.Module):def __init__(self, group, in_channels, out_channels, kernel_size, padding):super().__init__()# 初始化群卷积核,用于创建插值变换后的空间卷积核。self.kernel = InterpolativeGroupKernel(group=group,kernel_size=kernel_size,in_channels=in_channels,out_channels=out_channels)# 填充参数self.padding = paddingdef forward(self, x):"""执行群卷积@param x: 输入样本 [batch_dim, in_channels, group_dim, spatial_dim_1, spatial_dim_2]@return: 群的同质空间上的函数[batch_dim, out_channels, num_group_elements, spatial_dim_1, spatial_dim_2]"""# 现在我们将输入的群维度折叠到输入通道维度中。## 你的代码从这里开始 ##x = x.view(x.size(0), -1, x.size(-2), x.size(-1)) # 折叠群维度到输入通道## 你的代码在这里结束 ### 我们获得在群作用下变换的卷积核。## 你的代码从这里开始 ##conv_kernels = self.kernel.sample() # 获取变换后的卷积核## 你的代码在这里结束 ### 应用群卷积,注意 reshape 将核的 '输出' 群维度折叠到输出通道维度,# 并将 '输入' 群维度折叠到输入通道维度。# 问题:你明白我们(可以)为什么这么做吗?## 你的代码从这里开始 ##x = torch.nn.functional.conv2d(x, conv_kernels, # 应用变换后的卷积核padding=self.padding)## 你的代码在这里结束 ### 重新塑形 [batch_dim, in_channels * num_group_elements, spatial_dim_1, spatial_dim_2]# 变为 [batch_dim, in_channels, num_group_elements, spatial_dim_1, spatial_dim_2],# 将通道和群维度分开。x = x.view(-1,self.kernel.out_channels,self.kernel.group.elements().numel(),x.shape[-1],x.shape[-2])return x
在 GroupConvolution
类中,__init__
方法初始化了 self.kernel
,它是一个 InterpolativeGroupKernel
实例,用于生成变换后的卷积核。forward
方法执行群卷积,首先通过 view
函数将输入 x
的群维度折叠到输入通道维度中,然后调用 self.kernel.sample()
获取变换后的卷积核。接着,使用 PyTorch 的 conv2d
函数应用这些核到输入 x
上。
在 forward
方法中,我们使用 reshape
方法重新排列输出的维度,以将群元素的维度与输出通道的维度分开。
最后,forward
方法返回群卷积的结果。
3.3. 定义GCNN
本节的主要内容是通过投影获得不变性并将所有内容整合在一起。
到目前为止,我们的特征图与 R 2 ⋊ H \mathbb{R}^2 \rtimes H R2⋊H的群作用等变;我们的特征图定义在 R 2 ⋊ H \mathbb{R}^2 \rtimes H R2⋊H上。在CNN中,一系列卷积层构建了一个表示,随后是一个(或多个)线性层。为了使用我们的提升和群卷积操作创建一个对群作用完全不变的GCNN,我们必须对我们的特征图应用一个对群作用不变的投影操作,以将其维度从[batch, channels, num_group_elem, spatial_1, spatial_2]降低到[batch, channels]或更低。这样获得的表示对群是完全不变的。然后,这个表示被送入最终的线性层以产生分类结果。
下面,我们将使用我们实现的PyTorch模块来构建一个小的GCNN。
注意:* 我们使用平均池化操作在群和空间维度上进行池化,但也可以使用最大池化、最小池化或任何对群不变的其他操作。
下面的代码定义了一个名为 GroupEquivariantCNN
的类,它是一个构建在 PyTorch 之上的群等变卷积神经网络模型。
from torch.nn import AdaptiveAvgPool3dclass GroupEquivariantCNN(torch.nn.Module):def __init__(self, group, in_channels, out_channels, kernel_size, num_hidden, hidden_channels):super().__init__()# 创建提升卷积。## 你的代码从这里开始 ##self.lifting_conv = GroupConvolution(group=group,in_channels=in_channels,out_channels=hidden_channels,kernel_size=kernel_size,padding=kernel_size // 2 # 假设使用零填充保持维度不变)## 你的代码在这里结束 ### 创建一组群卷积。self.gconvs = torch.nn.ModuleList()## 你的代码从这里开始 ##for i in range(num_hidden):self.gconvs.append(GroupConvolution(group=group,in_channels=hidden_channels,out_channels=hidden_channels,kernel_size=kernel_size,padding=kernel_size // 2 # 假设使用零填充保持维度不变))## 你的代码在这里结束 ### 创建投影层。提示:检查此单元格顶部的导入。## 你的代码从这里开始 ##self.projection_layer = torch.nn.AdaptiveAvgPool3d(1)## 你的代码在这里结束 ### 以及一个用于分类的最终线性层。self.final_linear = torch.nn.Linear(hidden_channels, out_channels)def forward(self, x):# 提升并解耦输入中的特征。x = self.lifting_conv(x)x = torch.nn.functional.layer_norm(x, x.shape[-4:])x = torch.nn.functional.relu(x)# 应用群卷积。for gconv in self.gconvs:x = gconv(x)x = torch.nn.functional.layer_norm(x, x.shape[-4:])x = torch.nn.functional.relu(x)# 为确保等变性,对群和空间维度应用最大池化。x = self.projection_layer(x).squeeze()x = self.final_linear(x)return x
在 GroupEquivariantCNN
类中,构造函数 __init__
初始化了模型的主要组件:
- 提升卷积 (
self.lifting_conv
):使用GroupConvolution
类创建,用于提取输入特征并将其从输入空间提升到同质空间。 - 群卷积层列表 (
self.gconvs
):使用torch.nn.ModuleList
存储多个GroupConvolution
层,用于在网络中进一步提取特征。 - 投影层 (
self.projection_layer
):使用AdaptiveAvgPool3d
类创建,用于将特征图的每个通道的尺寸压缩至 1x1x1,通常用于提高模型的等变性。 - 最终线性层 (
self.final_linear
):一个全连接层,用于将特征映射到最终的输出类别。
forward
方法定义了模型的前向传播过程:
- 输入
x
首先通过提升卷积,然后进行层归一化和ReLU激活。 - 输入接着通过多个群卷积层,每层后面都跟着层归一化和ReLU激活。
- 使用投影层对特征进行池化操作,并通过
squeeze
去除单维度。 - 最后,使用最终线性层将特征映射到输出类别。
注意:在实际使用中,需要确保 GroupConvolution
类已经正确定义,并且 InterpolativeGroupKernel
类能够生成适用于群卷积的核。此外,kernel_size // 2
作为填充参数是一种常见的实践,用于保持特征图的空间尺寸不变。
作为比较,我们可以创建一个几乎相同的CNN,唯一区别是下面这个网络由常规的卷积操作组成。
下面的代码定义了一个名为 CNN
的类,它是一个简单的卷积神经网络模型,使用 PyTorch 框架实现。
class CNN(torch.nn.Module):def __init__(self, in_channels, out_channels, kernel_size, num_hidden, hidden_channels):super().__init__()# 定义第一个卷积层self.first_conv = torch.nn.Conv2d(in_channels=in_channels, # 输入通道数out_channels=hidden_channels, # 输出通道数kernel_size=kernel_size, # 卷积核大小padding=0 # 填充为0,不进行填充)# 创建多个卷积层的列表self.convs = torch.nn.ModuleList()for i in range(num_hidden):self.convs.append(torch.nn.Conv2d(in_channels=hidden_channels, # 卷积层的输入通道数为前一层的输出通道数out_channels=hidden_channels,kernel_size=kernel_size,padding=0))# 定义最后的全连接层,用于分类self.final_linear = torch.nn.Linear(hidden_channels, out_channels)def forward(self, x):# 定义前向传播过程# 通过第一个卷积层x = self.first_conv(x)# 进行层归一化x = torch.nn.functional.layer_norm(x, x.shape[-3:])# 应用ReLU激活函数x = torch.nn.functional.relu(x)# 通过所有隐藏的卷积层for conv in self.convs:x = conv(x)x = torch.nn.functional.layer_norm(x, x.shape[-3:])x = torch.nn.functional.relu(x)# 应用平均池化,将特征图的每个通道压缩到1x1尺寸x = torch.nn.functional.adaptive_avg_pool2d(x, 1).squeeze()# 通过最后的全连接层x = self.final_linear(x)return x
在这个 CNN
类中:
__init__
方法初始化了网络的结构,包括第一个卷积层self.first_conv
,多个隐藏卷积层组成的列表self.convs
,以及最后的全连接层self.final_linear
。forward
方法定义了数据通过模型的流程。输入x
首先通过第一个卷积层,然后进行层归一化和ReLU激活。之后,输入通过所有隐藏卷积层,每层后面都跟着层归一化和ReLU激活。最后,使用adaptive_avg_pool2d
函数进行平均池化,将特征图尺寸压缩到1x1,并通过squeeze
去除单维度,然后通过全连接层输出最终结果。
这个模型是一个典型的卷积神经网络结构,适用于图像分类等任务。
4. 编程实践
4.1. 数据预处理
为了展示常规群卷积网络的泛化能力,我们将在MNIST训练数据集上训练这个模型,但在MNIST测试集的一个增强版本上进行评估,其中每个图像都被随机旋转了一个在 [ 0 , 2 π ] [0, 2\pi] [0,2π]之间的连续旋转角度。
# 对训练数据进行标准化处理。
train_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), # 将PIL图像或Numpy数组转换为Tensortorchvision.transforms.Normalize((0.1307,), (0.3081,)) # 对Tensor进行标准化,指定均值和标准差
])# 为了展示我们的旋转等变层带来的泛化能力,我们对测试集应用0到360度之间的随机旋转。
test_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.RandomRotation( # 对Tensor应用随机旋转[0, 360], # 随机旋转的角度范围torchvision.transforms.InterpolationMode.BILINEAR, # 指定插值方式为双线性插值fill=0 # 指定旋转后图像边缘的填充值),torchvision.transforms.Normalize((0.1307,), (0.3081,)) # 对旋转后的图像进行标准化
])# 我们在MNIST数据集上演示我们的模型。
train_ds = torchvision.datasets.MNIST(root=DATASET_PATH, # 指定数据集的根目录train=True, # 指定加载训练集transform=train_transform, # 应用训练集变换download=True) # 如果数据集不在本地,则下载数据集
test_ds = torchvision.datasets.MNIST(root=DATASET_PATH, # 指定数据集的根目录train=False, # 指定加载测试集transform=test_transform) # 应用测试集变换
train_loader = torch.utils.data.DataLoader(train_ds, # 创建训练集的数据加载器batch_size=64, # 指定批量大小shuffle=True) # 在每个epoch开始时打乱数据
test_loader = torch.utils.data.DataLoader(test_ds, # 创建测试集的数据加载器batch_size=64, # 指定批量大小shuffle=False) # 不打乱数据# 设置随机种子以确保结果的可复现性。
pl.seed_everything(12) # 设置随机种子
4.2.可视化数据集
# 设置展示图像的数量
NUM_IMAGES = 4# 从训练数据集中采样NUM_IMAGES个图像
images = [train_ds[idx][0] for idx in range(NUM_IMAGES)]# 将采样的图像从数据集中提取出来,并使用PIL库的fromarray方法转换成PIL图像
orig_images = [Image.fromarray(train_ds.data[idx].numpy()) for idx in range(NUM_IMAGES)]# 对原始图像应用测试集的变换
orig_images = [test_transform(img) for img in orig_images]# 将原始图像和变换后的图像堆叠起来,创建一个图像网格
# normalize=True表示标准化图像到[0,1]区间,pad_value=0.5表示网格之间的填充值为0.5
img_grid = torchvision.utils.make_grid(torch.stack(images + orig_images, dim=0), # 将图像列表堆叠成张量nrow=4, # 每行展示4个图像normalize=True, # 归一化pad_value=0.5 # 网格填充值
)# 调整图像网格的维度,以符合matplotlib的imshow要求
img_grid = img_grid.permute(1, 2, 0)# 创建一个图形,并设置其大小
plt.figure(figsize=(8, 8))# 设置标题为从MNIST训练集采样的图像,使用测试变换增强
plt.title("从MNIST训练集采样的图像,使用测试变换增强。")# 显示图像网格
plt.imshow(img_grid)# 关闭坐标轴
plt.axis('off')# 展示图形
plt.show()# 关闭图形
plt.close()
4.3.建立模型
定义了名为 DataModule
的类,它继承自 PyTorch Lightning 的 LightningModule
。这个类是用于构建 PyTorch Lightning 模型的模板,它提供了一种简便的方式来组织和训练深度学习模型。
class DataModule(pl.LightningModule):def __init__(self, model_name, model_hparams, optimizer_name, optimizer_hparams):"""构造函数输入:model_name - 要运行的模型/CNN的名称。用于创建模型(见下面的函数)model_hparams - 模型的超参数,作为字典。optimizer_name - 要使用的优化器名称。目前支持:Adam, SGDoptimizer_hparams - 优化器的超参数,作为字典。包括学习率、权重衰减等。"""super().__init__()# 将超参数导出到 YAML 文件,并创建 "self.hparams" 命名空间self.save_hyperparameters()# 创建模型self.model = create_model(model_name, model_hparams)# 创建损失模块self.loss_module = nn.CrossEntropyLoss()def forward(self, imgs):# 前向传播函数return self.model(imgs)def configure_optimizers(self):# 使用 AdamW 作为优化器,AdamW 是正确实现权重衰减的 Adam(详见:https://arxiv.org/pdf/1711.05101.pdf)optimizer = optim.AdamW(self.parameters(), **self.hparams.optimizer_hparams)return [optimizer], []def training_step(self, batch, batch_idx):# 训练步骤# "batch" 是训练数据加载器的输出imgs, labels = batchpreds = self.model(imgs) # 模型预测loss = self.loss_module(preds, labels) # 计算损失acc = (preds.argmax(dim=-1) == labels).float().mean() # 计算准确率# 在 tensorboard 中记录每个 epoch 的准确率(跨批次的加权平均)self.log('train_acc', acc, on_step=False, on_epoch=True)self.log('train_loss', loss)return loss # 返回张量以调用 ".backward" 进行反向传播def validation_step(self, batch, batch_idx):# 验证步骤imgs, labels = batchpreds = self.model(imgs).argmax(dim=-1) # 模型预测并取最大值作为预测类别acc = (labels == preds).float().mean() # 计算准确率# 默认情况下每个 epoch 记录一次(跨批次的加权平均),并在进度条中显示self.log('val_acc', acc, prog_bar=True)def test_step(self, batch, batch_idx):# 测试步骤imgs, labels = batchpreds = self.model(imgs).argmax(dim=-1)acc = (labels == preds).float().mean()# 默认情况下每个 epoch 记录一次(跨批次的加权平均),并在进度条中显示,并在之后返回self.log('test_acc', acc, prog_bar=True)
在这个 DataModule
类中:
__init__
方法接收模型名称、模型超参数、优化器名称和优化器超参数,并使用这些参数初始化模型和优化器。forward
方法定义了模型的前向传播。configure_optimizers
方法配置了模型训练过程中使用的优化器。这里使用了AdamW
,它是 Adam 优化器的一个变种,正确地实现了权重衰减。training_step
方法定义了单步训练的逻辑,包括计算损失和准确率,并记录训练过程中的准确率和损失。validation_step
和test_step
方法分别定义了验证和测试步骤的逻辑,它们计算并记录准确率。
这个类可以作为使用 PyTorch Lightning 框架进行模型训练的基础。通过继承这个类并实现特定的方法,可以方便地自定义模型训练、验证和测试过程。
定义创建和训练模型的函数,以及使用预训练模型的逻辑。
# 定义一个字典,包含不同的模型类
model_dict = {'CNN': CNN,'GCNN': GroupEquivariantCNN
}# 根据模型名称和超参数创建模型的函数
def create_model(model_name, model_hparams):if model_name in model_hparams:return model_dict[model_name](**model_hparams)else:assert False, f"未知的模型名称 \"{model_name}\"。可用的模型有:{str(model_dict.keys())}"
create_model
函数接收模型名称和模型的超参数字典,然后根据模型名称从 model_dict
字典中选择相应的模型类,并使用提供的超参数实例化它。如果模型名称不存在于字典中,则抛出异常。
# 训练模型的函数,可以加载预训练模型或从头开始训练
def train_model(model_name, save_name=None, **kwargs):"""输入:model_name - 要运行的模型名称。用于在 "model_dict" 中查找类save_name (可选) - 如果指定,这个名字将用于创建检查点和日志目录"""if save_name is None:save_name = model_name# 创建一个带有生成回调的 PyTorch Lightning 训练器trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, save_name), # 保存模型的位置accelerator='auto', # 我们在单个 GPU 上运行(如果可能)max_epochs=10, # 如果没有设置耐心,则训练的周期数callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"), # 根据记录的最大 val_acc 保存最佳检查点。只保存权重而不是优化器LearningRateMonitor("epoch")]) # 记录每个 epoch 的学习率trainer.logger._default_hp_metric = None # 可选的日志参数,我们不需要# 检查是否存在预训练模型。如果存在,加载它并跳过训练pretrained_filename = os.path.join(CHECKPOINT_PATH, save_name + ".ckpt")if os.path.isfile(pretrained_filename):print(f"在 {pretrained_filename} 发现预训练模型,正在加载...")model = DataModule.load_from_checkpoint(pretrained_filename) # 自动加载保存的模型和超参数else:pl.seed_everything(12) # 为了可复现性model = DataModule(model_name=model_name, **kwargs)trainer.fit(model, train_loader, test_loader)# 训练后加载最佳检查点model = DataModule.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)# 在测试集上测试最佳模型val_result = trainer.test(model.to(device), test_loader, verbose=False)result = {"val": val_result[0]["test_acc"]}return model, result
train_model
函数接收模型名称和其他关键字参数,使用 create_model
函数创建模型实例,并使用 PyTorch Lightning 的 Trainer
类进行训练。它还检查是否存在预训练模型的检查点,如果存在,则加载该模型而不是从头开始训练。训练完成后,它会在测试集上测试最佳模型,并返回模型和测试结果。
4.4.训练模型
4.4.1.训练CNN
# 训练名为 "CNN" 的模型,并将结果保存为 'cnn-pretrained'
cnn_model, cnn_results = train_model(model_name="CNN", # 指定模型名称为 "CNN"model_hparams={ # 模型的超参数字典"in_channels": 1, # 输入通道数为 1"out_channels": 10, # 输出通道数为 10,适用于10类分类问题"kernel_size": 5, # 卷积核大小为 5x5"num_hidden": 4, # 隐藏层的数量为 4"hidden_channels": 32 # 隐藏层的通道数为 32},optimizer_name="Adam", # 使用 Adam 优化器optimizer_hparams={ # 优化器的超参数字典"lr": 1e-2, # 学习率设置为 0.01"weight_decay": 1e-4, # 权重衰减设置为 0.0001,用于正则化},save_name='cnn-pretrained' # 保存模型的名称为 'cnn-pretrained'
)
代码通过 train_model
函数来启动模型训练流程。函数接收模型名称、模型和优化器的超参数以及保存名称作为输入。在这个例子中,我们训练一个简单的卷积神经网络(CNN),它有 1 个输入通道和 10 个输出通道,适用于分类任务。卷积层的核大小是 5x5,网络中有 4 个隐藏卷积层,每层有 32 个通道。
我们使用 Adam 优化器进行训练,设置学习率为 0.01,权重衰减为 0.0001,这有助于防止模型训练过程中的过拟合。
训练过程中,模型的检查点和日志将保存在与 save_name
同名的目录下。如果指定的检查点文件已存在,train_model
函数将加载预训练模型并跳过训练步骤。如果不存在,函数将初始化模型,进行训练,并在训练完成后评估模型性能。
最终,cnn_model
变量将包含训练完成的模型实例,而 cnn_results
字典将包含测试集上的准确率结果。
4.4.2.训练GCNN
紧接着,我们开始训练GCNN。为了平衡增加的卷积核维度,我们将通道数减半。这一举措旨在确保可训练参数的数量保持相对稳定。
# 训练名为 "GCNN" 的模型,并将结果保存为 'gcnn-pretrained'
gcnn_model, gcnn_results = train_model(model_name="GCNN", # 指定模型名称为 "GCNN"model_hparams={ # 模型的超参数字典"in_channels": 1, # 输入通道数为 1"out_channels": 10, # 输出通道数为 10,适用于10类分类问题"kernel_size": 5, # 卷积核大小为 5x5"num_hidden": 4, # 隐藏层的数量为 4"hidden_channels": 16, # 由于特征图中额外维度导致的可训练参数增加,减少隐藏层的通道数"group": CyclicGroup(order=4).to(device) # 定义群为阶数为4的循环群,并将群转换到相应的设备上},optimizer_name="Adam", # 使用 Adam 优化器optimizer_hparams={ # 优化器的超参数字典"lr": 1e-2, # 学习率设置为 0.01"weight_decay": 1e-4, # 权重衰减设置为 0.0001,用于正则化},save_name='gcnn-pretrained' # 保存模型的名称为 'gcnn-pretrained'
)
代码通过 train_model
函数来启动群卷积神经网络(GCNN)的训练流程。函数接收模型名称、模型和优化器的超参数以及保存名称作为输入。
在这个例子中,我们训练一个 GCNN,它有 1 个输入通道和 10 个输出通道,适用于分类任务。卷积层的核大小是 5x5,网络中有 4 个隐藏卷积层,每层有 16 个通道。由于 GCNN 在特征图中引入了额外的群维度,导致可训练参数数量增加,因此相比于普通 CNN,这里减少了隐藏层的通道数。
群卷积层需要一个群的定义,这里使用的是阶数为 4 的循环群,意味着群的元素对应于 0、90、180 和 270 度的旋转。这个群被转换到与模型相同的设备上(例如 CPU 或 GPU)。
我们使用 Adam 优化器进行训练,设置学习率为 0.01,权重衰减为 0.0001,这有助于防止模型训练过程中的过拟合。
训练过程中,模型的检查点和日志将保存在与 save_name
同名的目录下。如果指定的检查点文件已存在,train_model
函数将加载预训练模型并跳过训练步骤。如果不存在,函数将初始化模型,进行训练,并在训练完成后评估模型性能。
最终,gcnn_model
变量将包含训练完成的模型实例,而 gcnn_results
字典将包含测试集上的准确率结果。
现在,我们来分析一下两个模型的最终结果。显然,GCNN的表现显著优于CNN。这是因为GCNN对于90度的旋转是不变的,这使得它能够在经过此类旋转后仍然能够识别测试集中的手写数字。不过,由于测试图像是在0到360度之间进行了连续的随机旋转,GCNN模型仍然没有达到完美的准确率。那么,我们如何进一步提升GCNN的泛化能力呢?
4.5.观察生成的特征映射
为了更深入地理解在输入图像旋转时,CNN和GCNN的内部工作机制,我们可以查看网络第二层中某个特征映射的一个通道,该通道对应于输入图像经过不同旋转角度后的响应。
# 从MNIST数据集中加载测试数据集,不使用变换
train_ds = torchvision.datasets.MNIST(root=DATASET_PATH, train=False, transform=None)# 从测试数据集中获取一个图像。
digit, label = train_ds[123]# 将图像转换为张量。
digit = transforms.ToTensor()(digit)# 显示图像
plt.figure(figsize=(6, 6))
plt.imshow(digit.squeeze()) # 压缩图像维度,从[C, H, W]变为[H, W]
plt.title(f'Label: {label}') # 设置图像标题为标签
plt.show()# 设置绘图参数
plt.rcParams['figure.figsize'] = [10, 3]# 获取一组旋转角度,用于旋转图像
rots = torch.linspace(0, 360 - 360/8, 8) # 从0度到345度,每45度旋转一次# 旋转图像并应用标准化变换
rot_digit = torch.stack(tuple(torchvision.transforms.functional.rotate(digit, a.item(), torchvision.transforms.functional.InterpolationMode.BILINEAR) # 双线性插值旋转)for a in rots)
)
rot_digit = torchvision.transforms.Normalize((0.1307,), (0.3081,))(rot_digit) # 标准化旋转后的图像# 为每个旋转后的图像创建子图
fig, ax = plt.subplots(1, rots.numel())for idx, rotation in enumerate(rots):ax[idx].imshow(rot_digit[idx, :, :].squeeze()) # 显示旋转后的图像ax[idx].set_title(f"{int(rotation)} deg") # 设置子图标题为旋转角度fig.text(0.5, 0.04, 'Rotations of input image', ha='center') # 设置图的中心标题
plt.show()# 将图像输入CNN模型的前几层
cnn_out = cnn_model.model.first_conv(rot_digit)
cnn_out = torch.nn.functional.relu(torch.nn.functional.layer_norm(cnn_out, cnn_out.shape[-3:]))
for i in range(2):cnn_out = cnn_model.model.convs[i](cnn_out)cnn_out = torch.nn.functional.relu(torch.nn.functional.layer_norm(cnn_out, cnn_out.shape[-3:]))# 应用投影到剩余的空间维度,并压缩维度
projected_cnn_out = torch.nn.functional.adaptive_avg_pool2d(cnn_out, 1).squeeze()# 将图像输入GCNN模型的前几层
gcnn_out = gcnn_model.model.lifting_conv(rot_digit)
gcnn_out = torch.nn.functional.relu(torch.nn.functional.layer_norm(gcnn_out, gcnn_out.shape[-4:]))
for i in range(2):gcnn_out = gcnn_model.model.gconvs[i](gcnn_out)gcnn_out = torch.nn.functional.relu(torch.nn.functional.layer_norm(gcnn_out, gcnn_out.shape[-4:]))# 应用投影到GCNN模型的等变表示
projected_gcnn_out = torch.mean(gcnn_out, dim=(-3, -2, -1)) # 计算GCNN输出在剩余空间维度上的平均值
代码首先从MNIST数据集中获取一个图像,并对其进行一系列操作,包括转换为张量、显示图像、旋转图像、标准化处理、可视化旋转效果。接着,代码演示了如何将图像输入到CNN和GCNN模型的前几层,并观察经过每一层后的输出效果。对于CNN,使用了ReLU激活函数和层归一化。对于GCNN,除了ReLU和层归一化外,还使用了提升卷积。最后,代码展示了如何对CNN和GCNN的输出应用投影操作,以观察模型对输入图像的变换。
5.总结和展望
5.1 总结
本文详细介绍了卷积神经网络(CNN)和图卷积网络(GCNN)的理论知识和实现方法。我们首先回顾了CNN中的关键概念,包括卷积操作、激活函数、池化操作等,并讨论了如何在PyTorch框架中实现这些操作。接着,我们深入探讨了GCNN的理论基础,包括群论、群卷积核的定义和实现,以及如何在PyTorch Lightning框架中构建和训练GCNN模型。
在实现部分,我们提供了一个名为DataModule
的类,它简化了模型的创建、训练和测试过程。我们还展示了如何使用train_model
函数来训练CNN和GCNN模型,并如何在MNIST数据集上评估它们的性能。此外,我们还探讨了如何观察和分析模型生成的特征映射,以更好地理解模型在处理图像旋转时的内部机制。
实验结果表明,GCNN在处理经过旋转的图像时,相较于CNN展现出了更好的泛化能力。这得益于GCNN的等变性特性,使其能够捕捉到图像的不变特征。我们还讨论了如何通过投影操作获得对群作用不变的表示,进一步提高模型的泛化性能。
5.2 展望
尽管本文全面探讨了CNN和GCNN的理论基础与实现,未来研究仍有多个有前景的发展方向:
5.2.1 模型性能与应用范围的扩展
首先,尽管GCNN在处理旋转图像方面显示出优越的泛化能力,但其在其他图像变换(如缩放、剪切)下的性能仍需提升。此外,本文的实验集中在MNIST数据集上,未来的研究应考虑更广泛的数据集和应用场景,包括自然图像分类和医学图像分析等,以全面评估和提升CNN和GCNN的实用性和泛化能力。
5.2.2 模型优化与理论深化
其次,研究不同的优化器和正则化技术对模型性能的影响至关重要。本文提到的AdamW和权重衰减是提高模型性能的实例,但持续探索新的优化策略和正则化方法将有助于进一步提升模型效果。同时,构建更深层次和复杂度的网络结构,可以增强模型处理复杂数据的能力。深入的理论研究和模型解释性分析也将为理解模型决策过程提供支持。
5.2.3 计算效率与跨领域探索
最后,提高模型的训练和推理效率是实现在资源受限环境部署的关键。此外,将CNN和GCNN拓展到其他领域,如自然语言处理和强化学习,可以开拓新的应用前景并进一步探索其潜力。通过这些跨学科的应用,可以不断推动深度学习领域的创新和发展。