KAN 学习 Day4 —— MultKAN 正向传播代码解读及测试

在KAN学习Day1——模型框架解析及HelloKAN中,我对KAN模型的基本原理进行了简单说明,并将作者团队给出的入门教程hellokan跑了一遍;

在KAN 学习 Day2 —— utils.py及spline.py 代码解读及测试中,我对项目的基本模块代码进行了解释,并以单元测试的形式深入理解模块功能,其中还发现了一个细小的错误。

在KAN 学习 Day3 —— KANLayer.py 与 Symbolic_KANLayer.py 代码解读及测试中,我对两种KAN层的实现进行了解读,它们分别是 “基于B样条曲线的KAN层” 和 “基于 eq?c*f%28a*x+b%29+d 的KAN层” 。(在下文中就称 B样条KAN层 和 符号KAN层)

今天我们开始对完整的KAN网络进行剖析,根据之前的经验,MultKAN类应该包括网络初始化、层之间网格参数传递、反向传播参数更新、网络剪枝、画图等等操作。

目录

一、kan目录

二、MultKAN.py 

2.1 类注释

​​​​2.2 构造函数 __init__

2.3 节点数计算

2.4 前向传播 forward

 0. 方法定义及注释

1. 初始化阶段

2. 前向传播循环

2.5 训练方法 fit

三、总结


一、kan目录

kan目录结构如下,包括了模型源码、检查点、实验以及assets等

e12295be65d94b3381e647242dc51eba.pngcc0f7d4a5a5148c995fcd44ad9bbbab6.png

 先了解一下这些文件/文件夹的大致信息:

  • kan\__init__.py:用于初始化Python包,方便使用时导入模块
  • kan\compiler.py:用于编译模型
  • kan\experiment.py:实验代码
  • kan\feynman.py:费曼函数,根据传入“name”的值确定函数,暂时没找到这个在哪里用到
  • kan\hypothesis.py:将函数进行线性分离,还包含一些画图函数
  •  kan\KANLayer.py:KAN层的实现,使用B样条曲线作为激活函数 
  • kan\LBFGS.py:这个文件名似乎昨天见过,训练时的opt参数。L-BFGS是一种用于无约束优化问题的算法,它是一种拟牛顿方法,特别适用于大型稀疏问题。
  • kan\MLP.py:作者自己实现了一个MLP,应该使来与KAN做对比的
  • kan\MultKAN.py:在KANLayer的基础上实现的KAN类的定义,提供了关于构建和配置这种网络的详细信息。
  • kan\spline.py:样条函数的实现
  •  kan\Symbolic_KANLayer.py:符号化的KAN层,使用四参线性函数作为激活函数 
  • kan\utils.py:通用模块
  • kan\.ipynb_checkpoints:看目录名,这个文件夹下存放的应该是检查点文件,但是似乎和模型的实现代码区别不大,没遇到过,还不知道有什么用。
  • kan\assets:这个目录下存放了两张图片,一张加号一张乘号,应该是对函数进行线性分离后,可视化时用的
  • kan\experiments:这个目录下是experiment1.ipynb,和昨天跑的hellokan差不多,今天再跑一下

二、MultKAN.py 

import torch
import torch.nn as nn
import numpy as np
from .KANLayer import KANLayer
#from .Symbolic_MultKANLayer import *
from .Symbolic_KANLayer import Symbolic_KANLayer
from .LBFGS import *
import os
import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import copy
#from .MultKANLayer import MultKANLayer
import pandas as pd
from sympy.printing import latex
from sympy import *
import sympy
import yaml
from .spline import curve2coef
from .utils import SYMBOLIC_LIB
from .hypothesis import plot_tree

导入的这些依赖中,只有 LBFGS 和 plot_tree 我们还没介绍,这两部分内容我也没打算深入研究

  • LBFGS(Limited-memory BFGS)是一种优化算法,它主要用于求解无约束优化问题。
  • plot_tree则是画出网络的树状图

2.1 类注释

class MultKAN(nn.Module):'''KAN classAttributes:-----------grid : intthe number of grid intervalsk : intspline orderact_fun : a list of KANLayerssymbolic_fun: a list of Symbolic_KANLayerdepth : intdepth of KANwidth : listnumber of neurons in each layer.Without multiplication nodes, [2,5,5,3] means 2D inputs, 3D outputs, with 2 layers of 5 hidden neurons.With multiplication nodes, [2,[5,3],[5,1],3] means besides the [2,5,53] KAN, there are 3 (1) mul nodes in layer 1 (2). mult_arity : int, or list of int listsmultiplication arity for each multiplication node (the number of numbers to be multiplied)grid : intthe number of grid intervalsk : intthe order of piecewise polynomialbase_fun : funresidual function b(x). an activation function phi(x) = sb_scale * b(x) + sp_scale * spline(x)symbolic_fun : a list of Symbolic_KANLayerSymbolic_KANLayerssymbolic_enabled : boolIf False, the symbolic front is not computed (to save time). Default: True.width_in : listThe number of input neurons for each layerwidth_out : listThe number of output neurons for each layerbase_fun_name : strThe base function b(x)grip_eps : floatThe parameter that interpolates between uniform grid and adaptive grid (based on sample quantile)node_bias : a list of 1D torch.floatnode_scale : a list of 1D torch.floatsubnode_bias : a list of 1D torch.floatsubnode_scale : a list of 1D torch.floatsymbolic_enabled : boolwhen symbolic_enabled = False, the symbolic branch (symbolic_fun) will be ignored in computation (set to zero)affine_trainable : boolindicate whether affine parameters are trainable (node_bias, node_scale, subnode_bias, subnode_scale)sp_trainable : boolindicate whether the overall magnitude of splines is trainablesb_trainable : boolindicate whether the overall magnitude of base function is trainablesave_act : boolindicate whether intermediate activations are saved in forward passnode_scores : None or list of 1D torch.floatnode attribution scoreedge_scores : None or list of 2D torch.floatedge attribution scoresubnode_scores : None or list of 1D torch.floatsubnode attribution scorecache_data : None or 2D torch.floatcached input dataacts : None or a list of 2D torch.floatactivations on nodesauto_save : boolindicate whether to automatically save a checkpoint once the model is modifiedstate_id : intthe state of the model (used to save checkpoint)ckpt_path : strthe folder to store checkpointsround : intthe number of times rewind() has been calleddevice : str'''

这段代码定义了一个名为 MultKAN 的类,它是基于 nn.Module 构建的,这个类具有众多的属性,用于描述和控制其行为和特征:

  • grid:网格的间隔数(使用网格进行参数优化)
  • k:分段多项式的阶数,或者说B样条的控制点数
  • act_fun:B样条KAN层列表
  • symbolic_fun:符号KAN层列表。
  • depth:表示模型的深度。
  • width:描述了各层神经元的数量。
  • mult_arity:与乘法节点的乘法运算的元数有关。
  • base_fun:公式中的eq?b%28x%29
  •  symbolic_enabled:布尔值,是否使用符号KAN层 
  • width_in 和 width_out:分别表示各层的输入和输出神经元数量。
  • base_fun_name:基础函数的名称。
  • grip_eps:可能用于在均匀网格和自适应网格之间进行插值。
  • 各种与偏差、缩放、训练相关的属性,如 node_biasnode_scale 等,用于控制模型的训练和参数调整。
  • 各种与分数、缓存、自动保存、设备等相关的属性,用于模型的评估、数据存储、模型保存和硬件设置等方面。 

嘛,就是说这里好多注释又重复了......

​​​​2.2 构造函数 __init__

    def __init__(self, width=None, grid=3, k=3, mult_arity = 2, noise_scale=0.3, scale_base_mu=0.0, scale_base_sigma=1.0, base_fun='silu', symbolic_enabled=True, affine_trainable=False, grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, seed=1, save_act=True, sparse_init=False, auto_save=True, first_init=True, ckpt_path='./model', state_id=0, round=0, device='cpu'):'''initalize a KAN modelArgs:-----width : list of intWithout multiplication nodes: :math:`[n_0, n_1, .., n_{L-1}]` specify the number of neurons in each layer (including inputs/outputs)With multiplication nodes: :math:`[[n_0,m_0=0], [n_1,m_1], .., [n_{L-1},m_{L-1}]]` specify the number of addition/multiplication nodes in each layer (including inputs/outputs)grid : intnumber of grid intervals. Default: 3.k : intorder of piecewise polynomial. Default: 3.mult_arity : int, or list of int listsmultiplication arity for each multiplication node (the number of numbers to be multiplied)noise_scale : floatinitial injected noise to spline.base_fun : strthe residual function b(x). Default: 'silu'symbolic_enabled : boolcompute (True) or skip (False) symbolic computations (for efficiency). By default: True. affine_trainable : boolaffine parameters are updated or not. Affine parameters include node_scale, node_bias, subnode_scale, subnode_biasgrid_eps : floatWhen grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.grid_range : list/np.array of shape (2,))setting the range of grids. Default: [-1,1]. This argument is not important if fit(update_grid=True) (by default updata_grid=True)sp_trainable : boolIf true, scale_sp is trainable. Default: True.sb_trainable : boolIf true, scale_base is trainable. Default: True.device : strdeviceseed : intrandom seedsave_act : boolindicate whether intermediate activations are saved in forward passsparse_init : boolsparse initialization (True) or normal dense initialization. Default: False.auto_save : boolindicate whether to automatically save a checkpoint once the model is modifiedstate_id : intthe state of the model (used to save checkpoint)ckpt_path : strthe folder to store checkpoints. Default: './model'round : intthe number of times rewind() has been calleddevice : strReturns:--------selfExample------->>> from kan import *>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)checkpoint directory created: ./modelsaving model version 0.0'''

这段代码是 MultKAN 类的构造函数 __init__ 的定义。构造函数用于初始化一个 MultKAN 模型实例,并为其设置各种参数。

参数说明:

  • width:一个整数列表,指定了每一层的神经元数量。如果没有乘法节点,列表中的每个元素代表相应层的神经元数量;如果有乘法节点,列表中的元素是一个包含神经元数量和乘法节点数量的元组。
  • grid:网格间隔的数量,默认为3。
  • k:分段多项式的阶数,默认为3。
  • mult_arity:每个乘法节点的乘法运算的元数,可以是单个整数或整数列表。
  • noise_scale:注入到样条函数中的初始噪声的缩放比例。
  • scale_base_mu 和 scale_base_sigma:基础函数的缩放参数的均值和标准差。
  • base_fun:残差函数 b(x) 的类型,默认为 'silu'。
  • symbolic_enabled:是否启用符号计算,默认为True。
  • affine_trainable:是否更新仿射参数,包括节点缩放、节点偏差、子节点缩放和子节点偏差。
  • grid_eps:用于在均匀网格和自适应网格之间进行插值的参数。
  • grid_range:设置网格范围的列表或NumPy数组。
  • sp_trainable:如果为真,则spline的缩放是可训练的。
  • sb_trainable:如果为真,则基础函数的缩放是可训练的。
  • device:指定设备,如 'cpu' 或 'cuda'。
  • seed:随机种子,用于初始化权重。
  • save_act:指示是否在正向传递中保存中间激活。
  • sparse_init:是否进行稀疏初始化。
  • auto_save:指示是否在修改模型后自动保存检查点。
  • state_id:模型的当前状态,用于保存检查点。
  • ckpt_path:存储检查点的文件夹路径。
  • roundrewind() 被调用次数。
  • device:设备类型。

代码说明:

        super(MultKAN, self).__init__()
  • 调用父类 MultKAN 的初始化方法,用于设置一些基本的属性或执行一些初始化操作。
        torch.manual_seed(seed)np.random.seed(seed)random.seed(seed)
  • 这三行代码设置了随机数种子,确保每次运行代码时生成的随机数序列相同,这对于测试和调试非常有用。据说将seed设置为3407会将模型的性能提升1%

        ### initializeing the numerical front ###self.act_fun = []self.depth = len(width) - 1
  • 这里初始化了激活函数列表 self.act_fun 和模型的深度 self.depth,深度是通过宽度列表的长度减一得到的。
        for i in range(len(width)):if type(width[i]) == int:width[i] = [width[i],0]self.width = width
  • 遍历宽度列表,如果宽度为整数,则将其转换为列表形式,形式为 [宽度, 0]。
  • 将宽度列表赋值给 self.width 属性。
  • 注意到,注释中的width属性是有两种形式的,这几行代码使其都转化为了第二种形式,即如果有乘法节点,列表中的元素是一个包含神经元数量和乘法节点数量的元组。
        # if mult_arity is just a scalar, we extend it to a list of lists# e.g, mult_arity = [[2,3],[4]] means that in the first hidden layer, 2 mult ops have arity 2 and 3, respectively;# in the second hidden layer, 1 mult op has arity 4.if isinstance(mult_arity, int):self.mult_homo = True # when homo is True, parallelization is possibleelse:self.mult_homo = False # when home if False, for loop is required. self.mult_arity = mult_arity
  • 如果 mult_arity 是一个标量(即单个数字),代码将把它扩展为一个列表的列表。这样做通常是为了将单一的参数应用到多个乘法操作上。
  • 例如,如果 mult_arity = [[2,3],[4]],这意味着在第一个隐藏层中有两个乘法操作,它们的参数分别是 2 和 3;在第二个隐藏层中有一个乘法操作,其参数是 4。
  • 这里检查 mult_arity 是否是一个整数。如果是,那么所有乘法操作的参数都是相同的,这意味着它们是同质的。在这种情况下,可以将这些操作并行化,以提高计算效率。因此,将 self.mult_homo 设置为 True
  • 如果 mult_arity 不是一个整数,那么它可能是一个列表的列表,其中包含不同层级的不同参数。在这种情况下,不能并行化乘法操作,因为每个操作的参数可能不同。因此,将 self.mult_homo 设置为 False,这意味着可能需要使用循环来处理每个操作。
  • 最后,将处理后的 mult_arity 参数赋值给 self.mult_arity,这样模型就可以使用这个参数来定义其乘法操作了。
        width_in = self.width_inwidth_out = self.width_out

调用了两个方法,获得了KAN层真正的输入输出节点数。

        self.base_fun_name = base_funif base_fun == 'silu':base_fun = torch.nn.SiLU()elif base_fun == 'identity':base_fun = torch.nn.Identity()elif base_fun == 'zero':base_fun = lambda x: x*0.
  • 将传入的 base_fun 参数赋值给实例变量 self.base_fun_name。这意味着 base_fun 是一个字符串,它表示想要使用的基础函数的名称。
  • 如果 base_fun_name 是字符串 'silu',那么代码将创建一个 torch.nn.SiLU() 对象。SiLU(Sigmoid-weighted Linear Unit)是一个激活函数,通常用于神经网络中。
  • 如果 base_fun_name 是字符串 'identity',那么代码将创建一个 torch.nn.Identity() 对象。Identity 函数是一个恒等函数,它直接返回其输入值,通常用作默认激活函数或不改变输入的层。
  • 如果 base_fun_name 是字符串 'zero',那么代码将创建一个匿名函数(lambda 函数),这个函数将任何输入 x 乘以 0,从而输出 0。这可能表示一个“关闭”激活状态的函数,不激活任何神经元。
        self.grid_eps = grid_epsself.grid_range = grid_range
  • 将网格相关的参数赋值给 self.grid_eps 和 self.grid_range
  • grid_eps:控制网格细化策略的浮点数,默认为0.02。当 grid_eps = 1 时,网格是均匀的;当 grid_eps = 0 时,它使用样本的百分位数进行分区。0 < grid_eps < 1 插值在两种极端之间。
        for l in range(self.depth):# splinessp_batch = KANLayer(in_dim=width_in[l], out_dim=width_out[l+1], num=grid, k=k, noise_scale=noise_scale, scale_base_mu=scale_base_mu, scale_base_sigma=scale_base_sigma, scale_sp=1., base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, sparse_init=sparse_init)self.act_fun.append(sp_batch)
  • 这段代码在循环中为每一层创建一个 KANLayer 实例,并把这些实例添加到一个列表中,以便后续可以用于KAN网络模型。
        self.node_bias = []self.node_scale = []self.subnode_bias = []self.subnode_scale = []
  •  初始化用于节点和子节点的偏差和缩放参数的列表。
        globals()['self.node_bias_0'] = torch.nn.Parameter(torch.zeros(3,1)).requires_grad_(False)exec('self.node_bias_0' + " = torch.nn.Parameter(torch.zeros(3,1)).requires_grad_(False)")
  • globals() 返回当前全局命名空间中的所有全局变量。
  • self.node_bias_0 是类的一个属性,这个属性在类的定义中尚未明确定义(即,它不是类的内部成员,而是通过全局命名空间访问的)。
  • torch.nn.Parameter(torch.zeros(3,1)) 创建了一个PyTorch的参数(Parameter对象),该对象是张量,用于在神经网络中存储权重,并且支持梯度计算。
  • .requires_grad_(False) 设置了该参数对象不进行梯度计算,即不会追踪其在计算图中的操作,这对于不需要计算梯度的参数(如偏置项)来说是合理的。
  • exec() 是Python的内置函数,用于执行字符串形式的Python代码。在这里,它被用来动态地创建或更新类属性。
  • 'self.node_bias_0' 是一个字符串,表示类中要创建或更新的属性名。
  • torch.nn.Parameter(torch.zeros(3,1)).requires_grad_(False) 是一个字符串表达式,创建了一个新的PyTorch参数并设置了其梯度计算为False。

这种做法在某些情况下非常有用,比如在定义神经网络模型时,需要动态地为特定的参数创建属性,或者在模型中为某些不需要梯度计算的参数(如偏置项)创建独立的属性。

但是!我没找到这两行代码有啥用处。

        for l in range(self.depth):exec(f'self.node_bias_{l} = torch.nn.Parameter(torch.zeros(width_in[l+1])).requires_grad_(affine_trainable)')exec(f'self.node_scale_{l} = torch.nn.Parameter(torch.ones(width_in[l+1])).requires_grad_(affine_trainable)')exec(f'self.subnode_bias_{l} = torch.nn.Parameter(torch.zeros(width_out[l+1])).requires_grad_(affine_trainable)')exec(f'self.subnode_scale_{l} = torch.nn.Parameter(torch.ones(width_out[l+1])).requires_grad_(affine_trainable)')exec(f'self.node_bias.append(self.node_bias_{l})')exec(f'self.node_scale.append(self.node_scale_{l})')exec(f'self.subnode_bias.append(self.subnode_bias_{l})')exec(f'self.subnode_scale.append(self.subnode_scale_{l})')
  • 通过循环,它为模型的每一层创建了节点偏置、节点缩放、子节点偏置和子节点缩放参数,并将这些参数存储在类的属性中,以便后续使用。
  • 通过 affine_trainable 参数来控制哪些参数是可训练的。
        self.act_fun = nn.ModuleList(self.act_fun)self.grid = gridself.k = kself.base_fun = base_fun

这几个基础的设置就不解释了。

        ### initializing the symbolic front ###self.symbolic_fun = []for l in range(self.depth):sb_batch = Symbolic_KANLayer(in_dim=width_in[l], out_dim=width_out[l+1])self.symbolic_fun.append(sb_batch)

刚刚创建了B样条KAN层,现在创建符号KAN层。

        self.symbolic_fun = nn.ModuleList(self.symbolic_fun)self.symbolic_enabled = symbolic_enabledself.affine_trainable = affine_trainableself.sp_trainable = sp_trainableself.sb_trainable = sb_trainable
  • 将符号层加入列表
  • 设置符号层是否可用
  • 设置符号层线性函数的四个参数是否可训练
  • 设置激活函数中的参数 eq?w_%7Bs%7D 是否可训练,分为了sp和sb两种,sp为B样条KAN层的,sb为符号KAN层的
        self.save_act = save_actself.node_scores = Noneself.edge_scores = Noneself.subnode_scores = Noneself.cache_data = Noneself.acts = Noneself.auto_save = auto_saveself.state_id = 0self.ckpt_path = ckpt_pathself.round = round

一些中间结果的存储变量和保存操作设置,保存的具体操作如下:

        if auto_save:if first_init:if not os.path.exists(ckpt_path):# Create the directoryos.makedirs(ckpt_path)print(f"checkpoint directory created: {ckpt_path}")print('saving model version 0.0')history_path = self.ckpt_path+'/history.txt'with open(history_path, 'w') as file:file.write(f'### Round {self.round} ###' + '\n')file.write('init => 0.0' + '\n')self.saveckpt(path=self.ckpt_path+'/'+'0.0')else:self.state_id = state_id

我们在hellokan中就见识过,模型在训练过程中会保存中间数据、状态和历史信息等内容

        self.input_id = torch.arange(self.width_in[0],)
  • 给输入节点编号 ,从0开始
        self.device = deviceself.to(device)def to(self, device):'''move the model to deviceArgs:-----device : str or deviceReturns:--------selfExample------->>> from kan import *>>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)>>> model.to(device)'''super(MultKAN, self).to(device)self.device = devicefor kanlayer in self.act_fun:kanlayer.to(device)for symbolic_kanlayer in self.symbolic_fun:symbolic_kanlayer.to(device)return self
  •  选择计算设备

测试:

from kan import *
torch.set_default_dtype(torch.float64)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)model = KAN(width=[2,[5,3],[5,1],3], mult_arity=[0,[2,3,4],[2],0],grid=3, k=3, seed=42, device=device)
model.input_id

cuda
checkpoint directory created: ./model
saving model version 0.0

tensor([0, 1])

2.3 节点数计算

2.3.1 width_in 

    @propertydef width_in(self):'''The number of input nodes for each layer'''width = self.widthwidth_in = [width[l][0]+width[l][1] for l in range(len(width))]return width_in

这段代码定义了一个属性 width_in ,它的作用是计算并返回模型每一层的输入节点数量。

  •  首先,获取了模型的宽度信息 width 。
  • 然后,通过列表推导式计算每一层输入节点的数量,计算方式是将每一层的总和维度 width[l][0] 和乘法操作维度 width[l][1] 相加。
  • 最后,返回计算得到的输入节点数量列表。

所以每层节点数=设置的节点数+乘法操作次数

测试:

width=[2,[5,3],[5,1],3],mult_arity = [[0],[2,3,4],[2],[0]]

print(model.width)
model.width_in

[[2, 0], [5, 3], [5, 1], [3, 0]]

 [2, 8, 6, 3]

2.3.2 width_out

    @propertydef width_out(self):'''The number of output subnodes for each layer'''width = self.widthif self.mult_homo == True:width_out = [width[l][0]+self.mult_arity*width[l][1] for l in range(len(width))]else:width_out = [width[l][0]+int(np.sum(self.mult_arity[l])) for l in range(len(width))]return width_out

这段代码定义了一个属性 width_out,其目的是计算并返回模型每一层的输出子节点数量。

  • 首先,获取了模型的宽度信息 width。然后根据 self.mult_homo 的值来决定计算输出节点数量的方式。
  • 如果 self.mult_homo 为 True,则使用列表推导式计算每一层的输出节点数量。计算方式是将每一层的总和维度 width[l][0] 与乘法操作维度 width[l][1] 的结果乘以 mult_arity 的值相加。mult_arity 是一个数组,表示每一层的乘法操作的幅度。
  • 如果 self.mult_homo 为 False,则使用列表推导式计算每一层的输出节点数量。计算方式是将每一层的总和维度 width[l][0] 与 mult_arity[l] 的元素之和相加。mult_arity[l] 是一个数组,表示每一层的乘法操作的幅度。
  • 最后,返回计算得到的输出节点数量列表。

测试:

width=[2,[5,3],[5,1],3],mult_arity = [[0],[2,3,4],[2],[0]]

print(model.width)
model.width_out

[[2, 0], [5, 3], [5, 1], [3, 0]]

[2, 14, 7, 3]

2.3.3 n_sum

    @propertydef n_sum(self):'''The number of addition nodes for each layer'''width = self.widthn_sum = [width[l][0] for l in range(1,len(width)-1)]return n_sum

这段代码定义了一个属性 n_sum ,用于计算并返回除了第一层和最后一层之外,每一层的总和维度 width[l][0] 所组成的列表。

首先,获取了模型的宽度信息 width 。然后通过列表推导式,从第二层到倒数第二层,提取出每一层的 width[l][0] ,并将这些值组成一个新的列表 n_sum ,最后返回这个列表。

测试:

width=[2,[5,3],[5,1],3],mult_arity = [[0],[2,3,4],[2],[0]]

print(model.width)
model.n_sum

[[2, 0], [5, 3], [5, 1], [3, 0]]
[5, 5]

2.3.4 n_mult

    @propertydef n_mult(self):'''The number of multiplication nodes for each layer'''width = self.widthn_mult = [width[l][1] for l in range(1,len(width)-1)]return n_mult

这段代码定义了一个属性 n_mult ,用于计算并返回除了第一层和最后一层之外,每一层的乘法节点数量。这里 width 是一个包含多层宽度信息的数据结构,每一层的信息以列表的形式存储,其中 width[l][1] 表示第 l 层的乘法节点数量。

通过列表推导式,代码遍历从第二层到倒数第二层的所有层,提取每一层的乘法节点数量,并将这些数量组成一个新的列表 n_mult 。最后,这个列表被返回给调用者。

测试:

width=[2,[5,3],[5,1],3],mult_arity = [[0],[2,3,4],[2],[0]]

print(model.width)
model.n_mult

[[2, 0], [5, 3], [5, 1], [3, 0]]

[3, 1]

2.3.5 feature_score

    @propertydef feature_score(self):'''attribution scores for inputs'''self.attribute()if self.node_scores == None:return Noneelse:return self.node_scores[0]

这段代码定义了一个名为 feature_score 的属性。其功能是计算输入的归因分数。

首先调用了 self.attribute() 方法。然后判断 self.node_scores 是否为 None ,如果是,则直接返回 None ;如果不是,则返回 self.node_scores 中的第一个元素。

这意味着只有在 self.node_scores 不为空的情况下,才会返回其第一个元素作为特征分数。

2.4 前向传播 forward

这个前向传播有点诡异,总感觉跟论文中的对不上,这次我们一边解释一边测试!

先来个简单的:width=[2,5,5,3],mult_arity = 2

from kan import *
torch.set_default_dtype(torch.float64)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)model = KAN(width=[2,5,5,3], mult_arity=2,grid=3, k=3, seed=42, device=device)
model.input_id

cuda
checkpoint directory created: ./model
saving model version 0.0
tensor([0, 1])

 测试数据:

x = torch.tensor([[1,2],[3,4],[5,6],[7,8],[9,10]]).float()
x = x.to(device)

 0. 方法定义及注释

    def forward(self, x, singularity_avoiding=False, y_th=10.):'''forward passArgs:-----x : 2D torch.tensorinputssingularity_avoiding : boolwhether to avoid singularity for the symbolic branchy_th : floatthe threshold for singularityReturns:--------NoneExample1-------->>> from kan import *>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)>>> x = torch.rand(100,2)>>> model(x).shapeExample2-------->>> from kan import *>>> model = KAN(width=[1,1], grid=5, k=3, seed=0)>>> x = torch.tensor([[1],[-0.01]])>>> model.fix_symbolic(0,0,0,'log',fit_params_bool=False)>>> print(model(x))>>> print(model(x, singularity_avoiding=True))>>> print(model(x, singularity_avoiding=True, y_th=1.))'''

参数说明:

  • x: 2D torch.tensor,输入数据。
  • singularity_avoiding: bool,默认为 False。如果为 True,则在符号分支中避免奇异点。
  • y_th: float,默认为 10.。用于判断是否避免奇异点的阈值。

返回值:

  • None:方法执行后不返回任何值

1. 初始化阶段

        x = x[:,self.input_id.long()]assert x.shape[1] == self.width_in[0]# cache dataself.cache_data = xself.acts = []  # shape ([batch, n0], [batch, n1], ..., [batch, n_L])self.acts_premult = []self.spline_preacts = []self.spline_postsplines = []self.spline_postacts = []self.acts_scale = []self.acts_scale_spline = []self.subnode_actscale = []self.edge_actscale = []# self.neurons_scale = []self.acts.append(x)  # acts shape: (batch, width[l])
  • 数据选择与验证

    • 选择输入数据 x 的特定列,并验证其形状是否符合模型的输入宽度要求。
    • 缓存输入数据 x
  • 初始化变量

    • 初始化用于存储不同层激活、尺度因子等的列表。

2. 前向传播循环

        for l in range(self.depth):
  • 循环遍历模型中的每一层,其中 self.depth 是模型的层数。
            x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x)#print(preacts, postacts_numerical, postspline)
  • 使用第 l 层的激活函数 act_fun[l] 对输入 x 进行处理。
  • 这里的激活函数是B样条KAN层的激活函数,详情见KANLayer
  • 处理结果包括数值分支的输出 x_numerical、预激活输出 preacts、后激活输出 postacts_numerical 和样条函数的输出 postspline。(对应的是y, preacts, postacts, postspline)
            if self.symbolic_enabled == True:x_symbolic, postacts_symbolic = self.symbolic_fun[l](x, singularity_avoiding=singularity_avoiding, y_th=y_th)else:x_symbolic = 0.postacts_symbolic = 0.
  • 可使用符号KAN层时,同样进行计算
            x = x_numerical + x_symbolic

这里要注意了,作者将两种层的计算结果相加了!也就是把B样条和线性函数同时叠加使用!

            # subnode affine transformx = self.subnode_scale[l][None,:] * x + self.subnode_bias[l][None,:]
  •  对激活函数的计算结果进行缩放,并增加偏置常数b

对以上这一部分内容做测试:

x = x[:,model.input_id.long()]
assert x.shape[1] == model.width_in[0]for l in range(model.depth):x_numerical, preacts, postacts_numerical, postspline = model.act_fun[l](x)#print(preacts, postacts_numerical, postspline)if model.symbolic_enabled == True:x_symbolic, postacts_symbolic = model.symbolic_fun[l](x, singularity_avoiding=False, y_th=10)else:x_symbolic = 0.postacts_symbolic = 0.x = x_numerical + x_symbolicx = model.subnode_scale[l][None,:] * x + model.subnode_bias[l][None,:]print(x)print(x.shape)

 tensor([[ 1.2935, -0.7047, -1.1071,  0.1673,  0.7162],
        [ 3.6752, -2.1692, -2.5181,  0.0475,  2.4007],
        [ 5.9759, -3.6323, -3.8994, -0.1110,  4.0748],
        [ 8.2045, -5.0443, -5.2469, -0.2554,  5.6880],
        [10.4148, -6.4431, -6.5866, -0.3956,  7.2852]], device='cuda:0',
       grad_fn=<AddBackward0>)
torch.Size([5, 5])
tensor([[-0.1832,  0.2447,  0.2546,  0.0981, -0.0997],
        [-0.8389,  1.2821,  1.2113,  0.4335, -0.0337],
        [-1.3981,  2.3204,  2.1626,  0.7829, -0.0287],
        [-1.9293,  3.2605,  3.0434,  1.1072, -0.0466],
        [-2.4603,  4.1674,  3.9076,  1.4242, -0.0795]], device='cuda:0',
       grad_fn=<AddBackward0>)
torch.Size([5, 5])
tensor([[ 0.0064,  0.0481, -0.1441],
        [ 0.4862,  0.2570, -0.7088],
        [ 1.0812,  0.6035, -1.4443],
        [ 1.6532,  0.8613, -2.0778],
        [ 2.1898,  1.0638, -2.6628]], device='cuda:0', grad_fn=<AddBackward0>)
torch.Size([5, 3])

 所有中间结果的形状都没有问题。

            if self.save_act:# save subnode_scaleself.subnode_actscale.append(torch.std(x, dim=0).detach())if self.save_act:postacts = postacts_numerical + postacts_symbolic# self.neurons_scale.append(torch.mean(torch.abs(x), dim=0))#grid_reshape = self.act_fun[l].grid.reshape(self.width_out[l + 1], self.width_in[l], -1)input_range = torch.std(preacts, dim=0) + 0.1output_range_spline = torch.std(postacts_numerical, dim=0) # for training, only penalize the spline partoutput_range = torch.std(postacts, dim=0) # for visualization, include the contribution from both spline + symbolic# save edge_scaleself.edge_actscale.append(output_range)self.acts_scale.append((output_range / input_range).detach())self.acts_scale_spline.append(output_range_spline / input_range)self.spline_preacts.append(preacts.detach())self.spline_postacts.append(postacts.detach())self.spline_postsplines.append(postspline.detach())self.acts_premult.append(x.detach())

 如果启用了保存激活函数尺度因子的选项,则计算并保存以下内容:

  • 子节点尺度因子(标准差)。
  • 边尺度因子(输出范围)。
  • 激活函数输出的尺度因子(输出范围与输入范围的比例)。
  • 样条部分的尺度因子。
  • 预激活输出、后激活输出和样条函数输出的副本。

但是我很好奇,这不是self.spline_postacts嘛,但是存的是postacts = postacts_numerical + postacts_symbolic,保存样条的激活输出为什么不只保存postacts_numerical。

还有就是都是判断 save_act,为啥用两个if。有时候就挺不能理解的

接下来介绍的这个东西,非常重要!它在基础节点的基础上引入了乘法操作。并且分为同质和非同质两种。

            # multiplicationdim_sum = self.width[l+1][0]dim_mult = self.width[l+1][1]
  • 获取下一次节点数以及乘法操作次数
  • self.width[l+1][0]是下一层的节点数
  • self.width[l+1][1]是乘法操作次数

对于上面的例子,有

x.shape: torch.Size([5, 5])
dim_sum: 5
dim_mult: 0
x.shape: torch.Size([5, 5])
dim_sum: 5
dim_mult: 0
x.shape: torch.Size([5, 3])
dim_sum: 3
dim_mult: 0

            if self.mult_homo == True:for i in range(self.mult_arity-1):if i == 0:x_mult = x[:,dim_sum::self.mult_arity] * x[:,dim_sum+1::self.mult_arity]else:x_mult = x_mult * x[:,dim_sum+i+1::self.mult_arity]
  • 当本层乘法参数都相同,则进行矩阵运算,即处理同质(homogeneous)乘法操作:
    • 在第一次循环(i == 0)中:
      • x_mult = x[:,dim_sum::self.mult_arity] * x[:,dim_sum+1::self.mult_arity]:对 x 的特定部分进行逐元素乘法。这里 x[:,dim_sum::self.mult_arity] 表示从 dim_sum 开始,每隔 self.mult_arity 个元素取一个元素,形成一个新的张量。同理 x[:,dim_sum+1::self.mult_arity] 表示从 dim_sum+1 开始取元素。这两个张量逐元素相乘得到 x_mult
    • 在后续的循环中(i != 0):
      • x_mult = x_mult * x[:,dim_sum+i+1::self.mult_arity]:将上一次乘法的结果与 x 的另一部分相乘。

对于我们width=[2,5,5,3],mult_arity = 2这个例子,有model.mult_homo == True,但结果如下:

tensor([], device='cuda:0', size=(5, 0), grad_fn=<MulBackward0>)
x_mult.shape: torch.Size([5, 0])
tensor([], device='cuda:0', size=(5, 0), grad_fn=<MulBackward0>)
x_mult.shape: torch.Size([5, 0])
tensor([], device='cuda:0', size=(5, 0), grad_fn=<MulBackward0>)
x_mult.shape: torch.Size([5, 0])

 由于dim_mult = 0,所以不进行乘法运算,代码中表现为dim_sum超出index,所以dim_sum::model.mult_arity都为0,自然乘积也为0。

测试升级:

设置width=[2,[5,2],[5,3],3], mult_arity=3,这是一个同质运算,由第一层向第二层传递时,会做乘法运算,次数为mult_arity-1=2,而乘法运算结果维度为dim_mult,然后与原始的dim_sum维度拼接,参数设置 width=[2,[5,1],[5,3],3], mult_arity=3,拼接操作:

            if self.width[l+1][1] > 0:x = torch.cat([x[:,:dim_sum], x_mult], dim=1)
  •  将x中未参与乘法计算的部分与乘法计算结果进行拼接,恢复原始张量形状 

x.shape: torch.Size([5, 8])
x_mult.shape: torch.Size([5, 1])
x_mult.shape: torch.Size([5, 1])
x.shape: torch.Size([5, 6])

 

x.shape: torch.Size([5, 14])
x_mult.shape: torch.Size([5, 3])
x_mult.shape: torch.Size([5, 3])
x.shape: torch.Size([5, 8])

 

x.shape: torch.Size([5, 3])
x_mult.shape: torch.Size([5, 0])
x_mult.shape: torch.Size([5, 0])
x.shape: torch.Size([5, 3])

 测试再次升级:

我用数据展示第二层的计算:

x = torch.tensor([[1,2,3,4,5,6,7,8,9,10,11,12,13,14]])
dim_sum = 5
dim_mult = 3
mult_arity = 3
if model.mult_homo == True:for i in range(2):print(f"第{i+1}次乘法:")if i == 0:print(x[:,dim_sum::mult_arity])print(x[:,dim_sum+1::mult_arity])x_mult = x[:,dim_sum::mult_arity] * x[:,dim_sum+1::mult_arity]else:print(x_mult)print(x[:,dim_sum+i+1::mult_arity])x_mult = x_mult * x[:,dim_sum+i+1::mult_arity]print(x_mult)if dim_mult > 0:x = torch.cat([x[:,:dim_sum], x_mult], dim=1)print(x)
print("x.shape:",x.shape)
print()

 第1次乘法:
tensor([[ 6,  9, 12]])
tensor([[ 7, 10, 13]])
tensor([[ 42,  90, 156]])


第2次乘法:
tensor([[ 42,  90, 156]])
tensor([[ 8, 11, 14]])
tensor([[ 336,  990, 2184]])


tensor([[   1,    2,    3,    4,    5,  336,  990, 2184]])
x.shape: torch.Size([1, 8])

 这下就完全理解它的乘法是如何运算的了。同质运算使用了矩阵运算以加快运算速度,这建立在mult_arity为常数的情况下,而当mult_arity的元素为列表时,只能进行遍历运算,如下:

            else:for j in range(dim_mult):acml_id = dim_sum + np.sum(self.mult_arity[l+1][:j])for i in range(self.mult_arity[l+1][j]-1):if i == 0:x_mult_j = x[:,[acml_id]] * x[:,[acml_id+1]]else:x_mult_j = x_mult_j * x[:,[acml_id+i+1]]if j == 0:x_mult = x_mult_jelse:x_mult = torch.cat([x_mult, x_mult_j], dim=1)
  • 当本层乘法参数不相同,则进行遍历运算,即处理非同质(non-homogeneous)乘法操作:
    • for j in range(dim_mult):循环遍历 dim_mult 次,dim_mult 表示乘法操作的次数。
    • 在每次循环中,计算 acml_id
      • acml_id = dim_sum + np.sum(self.mult_arity[l+1][:j]):计算当前乘法操作的起始索引。
    • 然后对每个乘法操作:
      • for i in range(self.mult_arity[l+1][j]-1)::循环遍历当前维度的乘法操作次数。
      • 在第一次循环(i == 0)中:
        • x_mult_j = x[:,[acml_id]] * x[:,[acml_id+1]]:对 x 的特定部分进行逐元素乘法。
      • 在后续的循环中(i != 0):
        • x_mult_j = x_mult_j * x[:,[acml_id+i+1]]:将上一次乘法的结果与 x 的另一部分相乘。
    • 如果是第一个乘法操作(j == 0):
      • x_mult = x_mult_j:将第一个乘法操作的结果赋值给 x_mult
    • 如果不是第一个乘法操作:
      • x_mult = torch.cat([x_mult, x_mult_j], dim=1):将当前乘法操作的结果与之前的结果在最后一个维度上连接。

测试:

参数设置 width=[2,[5,1],[5,3],3], mult_arity=[[0],[2],[2,3,4],[0]]

x = torch.tensor([[1,2],[3,4],[5,6],[7,8],[9,10]]).float()
x = x.to(device)x = x[:,model.input_id.long()]
assert x.shape[1] == model.width_in[0]for l in range(model.depth):x_numerical, preacts, postacts_numerical, postspline = model.act_fun[l](x)#print(preacts, postacts_numerical, postspline)if model.symbolic_enabled == True:x_symbolic, postacts_symbolic = model.symbolic_fun[l](x, singularity_avoiding=False, y_th=10)else:x_symbolic = 0.postacts_symbolic = 0.x = x_numerical + x_symbolicx = model.subnode_scale[l][None,:] * x + model.subnode_bias[l][None,:]#print(x)print("x.shape:",x.shape)# multiplicationdim_sum = model.width[l+1][0]dim_mult = model.width[l+1][1]#print("dim_sum:",dim_sum)#print("dim_mult:",dim_mult)if model.mult_homo == False:for j in range(dim_mult):acml_id = dim_sum + np.sum(model.mult_arity[l+1][:j])for i in range(model.mult_arity[l+1][j]-1):if i == 0:x_mult_j = x[:,[acml_id]] * x[:,[acml_id+1]]else:x_mult_j = x_mult_j * x[:,[acml_id+i+1]]print("x_mult_j.shape:",x_mult_j.shape )if j == 0:x_mult = x_mult_jelse:x_mult = torch.cat([x_mult, x_mult_j], dim=1)print("x_mult.shape:",x_mult.shape)if model.width[l+1][1] > 0:x = torch.cat([x[:,:dim_sum], x_mult], dim=1)

x.shape: torch.Size([5, 7])
x_mult_j.shape: torch.Size([5, 1])
x_mult.shape: torch.Size([5, 1])
x.shape: torch.Size([5, 6])

 

x.shape: torch.Size([5, 14])
x_mult_j.shape: torch.Size([5, 1])
x_mult.shape: torch.Size([5, 1])
x_mult_j.shape: torch.Size([5, 1])
x_mult_j.shape: torch.Size([5, 1])
x_mult.shape: torch.Size([5, 2])
x_mult_j.shape: torch.Size([5, 1])
x_mult_j.shape: torch.Size([5, 1])
x_mult_j.shape: torch.Size([5, 1])
x_mult.shape: torch.Size([5, 3])
x.shape: torch.Size([5, 8])

 

x.shape: torch.Size([5, 3])
x.shape: torch.Size([5, 3])

 来逐一分析:

  • 第一层到第二层计算,经过B样条KAN层和符号KAN层,x形状为[batch_size,dim_sum+sum(mult_arity[l+1])],其中sum(mult_arity[l+1])=dim_mult*mult_arity[l+1],因为mult_arity[l+1]只有一个元素。然后进行了一次乘法运算,并将结果拼接在x[:,:dim_sum]后面
  • 第二层到第三次计算:经过B样条KAN层和符号KAN层,x形状为[batch_size,dim_sum+sum(mult_arity[l+1])],其中sum(mult_arity[l+1])=np.sum(model.mult_arity[l+1][:j]),对于mult_arity[l+1]列表中的每一个元素,都执行其数值减一的乘法运算,运算结果x_mult_j的形状为[batch_size,1],最终获得的x_mult都是由x_mult_j拼接来的,最后将x_mult拼接在x[:,:dim_sum]后面。
  • 第三层到第四层同理。

使用数据展示第二层的计算:

x = torch.tensor([[1,2,3,4,5,6,7,8,9,10,11,12,13,14]])
dim_sum = 5
dim_mult = 3
mult_arity = [2,3,4]print(x)if model.mult_homo == False:for j in range(dim_mult):print(f"第{j+1}次运算:")acml_id = dim_sum + np.sum(mult_arity[:j])for i in range(mult_arity[j]-1):if i == 0:print(x[:,[acml_id]])print(x[:,[acml_id+1]])x_mult_j = x[:,[acml_id]] * x[:,[acml_id+1]]else:print(x_mult_j)print(x[:,[acml_id+i+1]])x_mult_j = x_mult_j * x[:,[acml_id+i+1]]print("x_mult_j:",x_mult_j)if j == 0:x_mult = x_mult_jelse:x_mult = torch.cat([x_mult, x_mult_j], dim=1)print("x_mult:",x_mult)if dim_mult > 0:x = torch.cat([x[:,:dim_sum], x_mult], dim=1)print(x)
print("x.shape:",x.shape)
print()

tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14]])
第1次运算:
tensor([[6]])
tensor([[7]])
x_mult_j: tensor([[42]])
x_mult: tensor([[42]])


第2次运算:
tensor([[8]])
tensor([[9]])
x_mult_j: tensor([[72]])
tensor([[72]])
tensor([[10]])
x_mult_j: tensor([[720]])
x_mult: tensor([[ 42, 720]])


第3次运算:
tensor([[11]])
tensor([[12]])
x_mult_j: tensor([[132]])
tensor([[132]])
tensor([[13]])
x_mult_j: tensor([[1716]])
tensor([[1716]])
tensor([[14]])
x_mult_j: tensor([[24024]])
x_mult: tensor([[   42,   720, 24024]])


tensor([[    1,     2,     3,     4,     5,    42,   720, 24024]])
x.shape: torch.Size([1, 8])

            # x = x + self.biases[l].weight# node affine transformx = self.node_scale[l][None,:] * x + self.node_bias[l][None,:]self.acts.append(x.detach())return x
  • 对拼接后的x进行缩放并且加上偏置常数 
  • 返回计算结果

我们理一下整个计算思路:

  1. 传入x后,首先检查x的形状,遍历KAN层进行计算:
    1. 再分别使用B样条KAN层和符号KAN层计算出x = x_numerical + x_symbolic
    2. 对x进行缩放处理,并加入偏置常数
    3. 乘法运算
      1. 同质乘法运算:对于dim_sum之外的维度,使用矩阵运算计算出x_mult
      2. 非同质乘法运算:根据mult_arity[l+1]列表一次计算出x_mult_j,拼接成x_mult
    4. 如进行了乘法运算,则将x_mult与x[:,:dim_sum]拼接
    5. 对x进行缩放处理,并加入偏置常数
    6. 返回x

2.5 训练方法 fit

通过对前向传播进行剖析,KAN网络并不像论文中展示的那么简单

  • KAN层包含了B样条层和符号层两种,我们可以设置是否使用符号层,如使用的话,中间x计算结果为两者之和。
  • KAN层节点过渡时引入了乘法操作,包括同质乘法和非同质乘法,在定义的基础维度上进行了扩展,进一步加强了网络的学习能力。

现在我们对MultKAN的fit方法的使用进行详解。

    def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., lamb_coef=0., lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1.,start_grid_update_step=-1, stop_grid_update_step=50, batch=-1,metrics=None, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n', display_metrics=None):'''trainingArgs:-----dataset : diccontains dataset['train_input'], dataset['train_label'], dataset['test_input'], dataset['test_label']opt : str"LBFGS" or "Adam"steps : inttraining stepslog : intlogging frequencylamb : floatoverall penalty strengthlamb_l1 : floatl1 penalty strengthlamb_entropy : floatentropy penalty strengthlamb_coef : floatcoefficient magnitude penalty strengthlamb_coefdiff : floatdifference of nearby coefficits (smoothness) penalty strengthupdate_grid : boolIf True, update grid regularly before stop_grid_update_stepgrid_update_num : intthe number of grid updates before stop_grid_update_stepstart_grid_update_step : intno grid updates before this training stepstop_grid_update_step : intno grid updates after this training steploss_fn : functionloss functionlr : floatlearning ratebatch : intbatch size, if -1 then full.save_fig_freq : intsave figure every (save_fig_freq) stepssingularity_avoiding : boolindicate whether to avoid singularity for the symbolic party_th : floatsingularity threshold (anything above the threshold is considered singular and is softened in some ways)reg_metric : strregularization metric. Choose from {'edge_forward_spline_n', 'edge_forward_spline_u', 'edge_forward_sum', 'edge_backward', 'node_backward'}metrics : a list of metrics (as functions)the metrics to be computed in trainingdisplay_metrics : a list of functionsthe metric to be displayed in tqdm progress barReturns:--------results : dicresults['train_loss'], 1D array of training losses (RMSE)results['test_loss'], 1D array of test losses (RMSE)results['reg'], 1D array of regularizationother metrics specified in metricsExample------->>> from kan import *>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)>>> dataset = create_dataset(f, n_var=2)>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);>>> model.plot()# Most examples in toturals involve the fit() method. Please check them for useness.'''

参数说明:

  1. dataset (dic): 包含训练集和测试集的数据字典,通常包括输入数据(train_inputtest_input)和标签数据(train_labeltest_label)。

  2. opt (str): 选择的优化器,可以是 "LBFGS"(L-BFGS)或 "Adam"。

  3. steps (int): 训练的总步骤数。

  4. log (int): 日志输出的频率,即每多少步骤输出一次日志。

  5. lamb (float): 总体正则化强度,用于控制模型复杂度。

  6. lamb_l1 (float): L1 正则化强度,用于惩罚模型参数的绝对值。

  7. lamb_entropy (float): 用于惩罚模型熵的强度,有助于防止过拟合。

  8. lamb_coef (float): 模型系数的大小惩罚强度。

  9. lamb_coefdiff (float): 邻近系数之间的差异惩罚强度,用于增加模型的平滑性。

  10. update_grid (bool): 如果为 True,则在训练步骤达到 stop_grid_update_step 之前定期更新网格。

  11. grid_update_num (int): 在 stop_grid_update_step 之前更新网格的次数。

  12. start_grid_update_step (int): 在这个步骤之前不进行网格更新。

  13. stop_grid_update_step (int): 这个步骤之后不进行网格更新。

  14. loss_fn (function): 自定义损失函数,用于计算模型的损失。

  15. lr (float): 学习率,决定每次更新参数时的步长。

  16. batch (int): 批处理大小,如果为 -1,则使用完整数据集。

  17. save_fig_freq (int): 每多少步骤保存一次训练结果的图形。

  18. singularity_avoiding (bool): 如果为 True,则在符号部分避免奇异点。

  19. y_th (float): 奇异点阈值,高于此值的任何值都将被视为奇异点。

  20. reg_metric (str): 用于计算正则化的度量标准,可以选择不同的选项如 edge_forward_spline_n 等。

  21. metrics (list of functions): 计算并返回的自定义度量列表。

  22. display_metrics (list of functions): 在训练进度条中显示的度量列表。

返回值:

  • results (dic): 包含训练过程中的关键信息的字典,包括:
    • train_loss: 训练集上的损失(通常为 RMSE)。
    • test_loss: 测试集上的损失(通常为 RMSE)。
    • reg: 正则化项的值。
    • 其他用户指定的度量。

测试1:

from kan import *
model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
model.plot()

checkpoint directory created: ./model
saving model version 0.0
| train_loss: 1.91e-02 | test_loss: 1.97e-02 | reg: 1.38e+01 | : 100%|█| 20/20 [00:07<00:00,  2.66it
saving model version 0.1

7bdaad0cace146f7b795d0de96bbae94.png

 测试2:

from kan import *
model = KAN(width=[2,[5,3],3], mult_arity=3, grid=5, k=3, noise_scale=0.3, seed=2)
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
model.plot()

checkpoint directory created: ./model
saving model version 0.0
| train_loss: 2.44e-02 | test_loss: 2.81e-02 | reg: 2.67e+01 | : 100%|█| 20/20 [00:09<00:00,  2.15it
saving model version 0.1

b03c0e119ad04e4eb44dc123370a436c.png

这个图包含了3个乘法节点。

三、总结

今天内容主要包括MultKAN网络的初始化、正向传播方法实现、训练方法参数说明。MultKAN网络正向传播有两个特点:

  1. 传播时可以同时使用KANLayer和Symbolic_KANLayer,以叠加的形式计算中间结果
  2. KAN节点的连接既有加法连接也有乘法连接,我们可以自定义乘法运算的方式(同质或非同质)

在上文中,我用数据直观展示了mult的计算过程,实际上只是连续的列相乘,因此在我看来,MultKAN的mult节点运算还有一定的优化空间,除了改善单一控制变量self.mult_homo,将其扩展为列表,还可以用numpy库实现连续列相乘的算法,这些尝试我打算放在实际应用中进行。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/53908.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

『功能项目』怪物的有限状态机【42】

本章项目成果展示 我们打开上一篇41项目优化 - 框架加载资源的项目&#xff0c; 本章要做的事情是按照框架的思想构建项目并完成怪物的自动巡逻状态&#xff0c;当主角靠近怪物时&#xff0c;怪物会朝向主角释放技能 首先新建脚本&#xff1a;BossCtrl.cs (通常把xxxCtrl.cs脚…

SpringBoot2:请求处理原理分析-利用内容协商功能实现接口的两种数据格式(JSON、XML)

文章目录 一、功能说明二、案例实现1、基于请求头实现2、基于请求参数实现 一、功能说明 我们知道&#xff0c;用ResponseBody注解标注的接口&#xff0c;默认返回给页面的是json数据。 其实&#xff0c;也可以返回xml结构的数据给页面。 这一篇就来实现一下这个小功能。 二、…

深入理解数据分析的使用流程:从数据准备到洞察挖掘

数据分析是企业和技术团队实现价值的核心。 5 秒内你能否让数据帮你做出决策&#xff1f; 通过本文&#xff0c;我们将深入探讨如何将原始数据转化为有意义的洞察&#xff0c;帮助你快速掌握数据分析的关键流程。 目录 数据分析的五个核心步骤1. 数据获取常用数据获取方式 2. 数…

【CS110L】Rust语言 Lecture3-4 笔记

文章目录 第三讲 所有权:移动与借用&例1例2例3 错误处理&#xff08;开头&#xff09;为什么空指针如此危险&#xff0c;我们能做什么以应对&#xff1f;— 引出Optionis_none()函数unwrap_or()函数常见用法 第四讲 代码实践:链表Box节点和链表的定义节点和链表的构造函数判…

charls基于夜神模拟器抓取安卓7.0应用程序https请求

charls基于夜神模拟器抓取安卓7.0应用程序https请求 1、安装charls(安装步骤这里就不详细说了)2、下载证书(证书后缀名 xx.pem)3、使用git bash生成证书hash4、上传证书到安卓的系统证书目录下(夜神模拟器方案)5、验证抓包1、安装charls(安装步骤这里就不详细说了) 2、…

【Vue】2

1 Vue 生命周期 Vue生命周期&#xff1a;一个 Vue 实例从 创建 到 销毁 的整个过程 创建(create)阶段&#xff1a;组件实例化时&#xff0c;初始化数据、事件、计算属性等挂载(mount)阶段&#xff1a;将模板渲染并挂载到 DOM 上更新(update)阶段&#xff1a;当数据发生变化时…

数据中台建设(六)—— 数据开发-提取数据价值

数据开发-提取数据价值 数据开发涉及的产品能力主要包括三部分&#xff1a;离线开发、实时开发和算法开发。 离线开发主要包括离线数据的加工、发布、运维管理&#xff0c;以及数据分析、数据探索、在线查询和及时分析相关工作。实时开发主要涉及数据的实时接入和实时处理。算…

网络高级(学习)2024.9.10

目录 一、Modbus简介 1.起源 2.特点 3.应用场景 二、Modbus TCP协议 1.特点 2.协议格式 3.MBAP报文头 4.功能码 5.寄存器 &#xff08;1&#xff09;线圈寄存器&#xff0c;类比为开关量&#xff0c;每一个bit都对应一个信号的开关状态。 &#xff08;2&#xff09…

[项目实战]EOS多节点部署

文章总览&#xff1a;YuanDaiMa2048博客文章总览 EOS多节点部署 &#xff08;一&#xff09;环境设计&#xff08;二&#xff09;节点配置&#xff08;三&#xff09;区块信息同步&#xff08;四&#xff09;启动节点并验证同步EOS单节点的环境如何配置 &#xff08;一&#xf…

第十一周:机器学习

目录 摘要 Abstract 一、字符级的RNN进行名字分类 1、准备数据 2、构造神经网络 3、训练 4、评价结果 5、预测 二、字符级的RNN生成名字 1、准备数据 2、构造神经网络 3、训练 4、网络采样&#xff08;预测&#xff09; 三、batch normalization 1、 feature n…

Bootstrap 警告信息(Alerts)使用介绍

本章将讲解警告&#xff08;Alerts&#xff09;以及 Bootstrap 所提供的用于警告的 class。警告&#xff08;Alerts&#xff09;向用户提供了一种定义消息样式的方式。它们为典型的用户操作提供了上下文信息反馈。 您可以为警告框添加一个可选的关闭按钮。为了创建一个内联的可…

【工具箱】NAND NOR FLASH闪存

随着国内集成电路的发展&#xff0c;特别是存储芯片方面&#xff0c;关于NOR Flash&#xff0c;NAND Flash&#xff0c;SD NAND, eMMC, Raw NAND的资料越来越多了。这里我专门写了这篇文章&#xff1a;1&#xff0c;把常用的存储产品做了分类; 2&#xff0c;把这些产品的特点做…

[Postman]接口自动化测试入门

文章大多用作个人学习分享&#xff0c;如果大家觉得有不足或错漏的地方欢迎评论指出或补充 此文章将完整的走一遍一个web页面的接口测试流程 大致路径为&#xff1a; 创建集合->调用接口登录获取token->保存token->带着token去完成其他接口的自动化测试->断言-&g…

Kafka下载与安装教程(国产化生产环境无联网服务器部署实操)

请放心观看&#xff0c;已在正式环境部署验证&#xff0c;流程无问题&#xff01; 所用系统为国产化麒麟银河 aarch64系统&#xff0c;部署时间2024年9月份&#xff01; [rootecs-xxxxxx-xxx ~]# cat /etc/os-release NAME"Kylin Linux Advanced Server" VERSION&q…

MySQL 查询数据库的数据总量

需求&#xff1a;查看MySQL数据库的数据总量&#xff0c;以MB为单位展示数据库占用的磁盘空间 实践&#xff1a; 登录到MySQL数据库服务器。 选择你想要查看数据总量的数据库&#xff1a; USE shield;运行查询以获取数据库的总大小&#xff1a; SELECT table_schema AS Datab…

go语言后端开发学习(七)——如何在gin框架中集成限流中间件

一.什么是限流 限流又称为流量控制&#xff08;流控&#xff09;&#xff0c;通常是指限制到达系统的并发请求数。 我们生活中也会经常遇到限流的场景&#xff0c;比如&#xff1a;某景区限制每日进入景区的游客数量为8万人&#xff1b;沙河地铁站早高峰通过站外排队逐一放行的…

JAVA毕业设计170—基于Java+Springboot+vue3+小程序的房屋租赁小程序系统(源代码+数据库)

毕设所有选题&#xff1a; https://blog.csdn.net/2303_76227485/article/details/131104075 基于JavaSpringbootvue3小程序的房屋租赁小程序系统(源代码数据库)170 一、系统介绍 本项目前后端分离(可以改为ssm版本)&#xff0c;分为用户、房东、管理员三种角色 1、用户&am…

[000-002-01].第03节:Linux系统下Oracle的安装与使用

2.1.Docker安装Oracle 在CentOS7中使用Docker安装Oracle&#xff1a; 1.安装Docker,详细请参考&#xff1a;https://blog.csdn.net/weixin_43783284/article/details/1211403682.拉取镜像&#xff1a; docker pull registry.cn-hangzhou.aliyuncs.com/helowin/oracle_11g3.下载…

删除有序数组中的重复项(同向指针(快慢指针))

题目&#xff1a; 算法分析&#xff1a; 快慢指针从0出发若快慢指针不相同&#xff0c;快指针替换慢指针&#xff08;即慢指针后一位&#xff09;快指针每次都会增加题目求不重复的元素个数&#xff08;slow 为对应元素索引&#xff0c;故个数为slow1&#xff09; 算法图解…

如何通过编程工具提升工作效率

目录 常用编程工具介绍 工具效率对比 未来发展趋势 结论 在当今软件开发的高效环境中&#xff0c;工具的选择至关重要。无论是编写代码、调试&#xff0c;还是协作开发&#xff0c;合适的编程工具不仅能够简化开发流程&#xff0c;还可以有效地提高生产力。开发者在日常工作…