import torch
import torch.nn.functional as F
from torch import nn
import math
- 这几行代码导入了PyTorch库,包括主要的
torch
模块、torch.nn.functional
模块(通常用于激活函数等),torch.nn
模块(用于定义神经网络组件),以及math
模块(用于基本的数学运算)。device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- 这行代码用于检测是否有可用的CUDA设备(即GPU),如果有,则使用GPU进行计算,否则使用CPU。
class PatchEmbedding(nn.Module):
- 这行代码定义了一个名为
PatchEmbedding
的类,它继承自torch.nn.Module
。nn.Module
是所有神经网络模块的基类,提供了很多用于构建网络的方法和属性。"""
Path embedding layer is nothing but a convolutional layer with kerneli size and stride equal to patch size.
"""
- 这是一个类级别的文档字符串(docstring),解释了
PatchEmbedding
类的功能:它实际上是一个卷积层,其卷积核大小和步长等于patch的大小。def __init__(self, in_channels, embedding_dim, patch_size):
- 这是
PatchEmbedding
类的构造函数,它接受三个参数:in_channels
(输入图像的通道数),embedding_dim
(表征的维度,即卷积层输出的通道数),patch_size
(patch的大小,同时也是卷积核的大小和步长)。super().__init__()
- 这行代码调用父类
nn.Module
的构造函数,是面向对象编程中的常规做法。self.patch_embedding = nn.Conv2d(
in_channels, embedding_dim, patch_size, patch_size
)
- 这行代码定义了一个二维卷积层
patch_embedding
,其输入通道数为in_channels
,输出通道数为embedding_dim
,卷积核大小和步长都为patch_size
。def forward(self, x):
- 这是
PatchEmbedding
类的forward
方法,它定义了如何执行前向传播。x
是输入数据。return self.patch_embedding(x)
- 这行代码执行了之前定义的
patch_embedding
卷积层的前向传播,并将结果返回。这样,输入图像就被转换成了patch表征。class KANLinear(nn.Module):
def __init__(
self,
in_features, # 输入特征的维度
out_features, # 输出特征的维度
grid_size=5, # B样条网格的大小,默认为5
spline_order=3, # B样条的阶数,默认为3
):
super().__init__() # 调用父类nn.Module的构造函数
self.in_features = in_features # 保存输入特征的维度
self.out_features = out_features # 保存输出特征的维度
self.grid_size = grid_size # 保存网格大小
self.spline_order = spline_order # 保存B样条的阶数
# Calculate the grid step size
grid_step = 2 / grid_size # 计算网格的步长,这里假设网格覆盖的区间是[-1, 1]
# Create the grid tensor
grid_range = torch.arange(-spline_order, grid_size + spline_order + 1)# 生成一个从-spline_order到grid_size+spline_order的序列
grid_values = grid_range * grid_step - 1# 将序列映射到[-1, 1]区间内,步长为grid_step
self.grid = grid_values.expand(in_features, -1).contiguous()# 将grid_values扩展为(in_features, grid_size+2*spline_order+1)的形状,并确保内存连续
# 初始化权重参数
self.base_weight = nn.Paramete