KAN(Kolmogorov-Arnold Network)的理解 3

系列文章目录

第一部分 KAN的理解——数学背景
第二部分 KAN的理解——网络结构
第三部分 KAN的实践——第一个例程


文章目录

  • 系列文章目录
  • 前言
  • KAN 的第一个例程 get started


前言

这里记录我对于KAN的探索过程,每次会尝试理解解释一部分问题。欢迎大家和我一起讨论。
KAN tutorial

KAN 的第一个例程 get started

以下内容包含对于代码的理解,对于KAN训练过程的理解和代码的解释。并且包含代码的结果。

  1. 对于KAN进行初始化。
from kan import *
# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5).
model = KAN(width=[2,5,1], grid=5, k=3, seed=0)

从上面的代码可以看出,输入两维,说明要拟合的数据有两个输入变量,hidden neurons5个说明是全连接网络,还没有进行剪枝。

gird intervel表示用于拟合的样条函数的一组离散点,这些点用于分段构造样条函数。网格设定的约密集对于拟合的函数精度越高,想要提高网络的拟合能力,一般会增加grid interval的数目,在论文中称为grid extension。

这里的k是指一次样条、二次样条等这里的次数。表示在每个区间内拟合函数时,使用的是多少次数的多项式表示。

seed为随机数种子,通过设置随机数种子seed=0,模型的初始化(如权重初始化)和任何涉及随机性的过程都会产生相同的结果。

  1. 创建数据集,用于作为训练的输入
# create dataset f(x,y) = exp(sin(pi*x)+y^2)
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
dataset['train_input'].shape, dataset['train_label'].shape

从输出和函数定义来看,默认KAN的train number和test number都是1000

create_dataset函数的功能为生成一系列的数据字典,包括train_input,train_label,test_input,test_label

第一行lambda函数用于定义匿名函数,接收二维函数x为输入,并返回一个新张量f,为其仅进行特定的数学运算并返回结果

  1. 绘制初始化结果
# plot KAN at initialization
model(dataset['train_input'][:20]);
model.plot(beta=100,sample=True)

额外提一句,在做初始化的时候,这里的有一些默认参数没给出来。
在初始化时,已经生成了每个节点的被学习的weight函数曲线的可视化,且被保存在./figures下,在初始化时添加了noise,所以每个节点的曲线形状不同,且在定义模型时还有supervised mode和unsupervised mode可以选择。

这部分代码的功能主要是,在初始化网络时给出了初始化时的可视化。结果如下:
在这里插入图片描述

  1. 模型训练并设置对应的参数
# train the model
model.train(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=10.);

一些参数:
dataset:输入的训练数据
opt:优化算法选择,有LBFGS和Adam算法可供选择,分别问基于二阶导数的算法和基于一阶导数的优化算法
step:训练步数
lamb:控制整体正则化项的强度,能够增强训练的稀疏性,保留有效项
lamb_entropy:控制熵正则化项的强度,能有效减少激活函数的数量,避免出现相同或非常相似的函数

从代码的内容上看,在训练中,已经在进行有效项的保留,重复项的去除。
1000的数据量大概要处理11s

画出此时的第一次训练后的图,发现被判定为不重要的项的透明度增强了许多,在图上显示表示为不重要的部分。

结果如下:
在这里插入图片描述

  1. 剪枝
# model.prune(mode='manual',active_neurons_id=[[3],[2]] )
model.prune()
model.plot(mask=False)

做一些剪枝,直接减掉一些不重要的node。prune的原则是查看每个node的入边和出边,
如果某个节点所连接的入边和出边的属于不重要的边,那么这些边可以被剪枝。
这里的默认参数是自动剪枝,但是实际上也可以选择手动剪枝,确要保留的节点。

  1. 再剪枝
model = model.prune()
model(dataset['train_input'][:20])
model.plot(sample=True)

再剪枝,得到更小的模型。这里的dataset[‘train_input’]应该是用来测试目前的训练结果的。结果如下:
在这里插入图片描述

  1. 再训练
model.train(dataset, opt="LBFGS", steps=50);

现在得到的结果是去掉了一些node的结果,在更少的nodes被保留的情况下,继续进行训练

从训练的结果可以结案到现在的精确度变高了,可能是因为减少了node,保留了可信度更强的node

  1. 再看一遍训练结果。
model.plot()

结果如下:
在这里插入图片描述

  1. 确定要fix的项
mode = "auto" # "manual"
# 设置mannual会报错if mode == "manual":# manual mode# fix_symbolic()方程下的参数,(layer index,layer index,output neuron index)model.fix_symbolic(0,0,0,'sin');model.fix_symbolic(0,1,0,'x^2');model.fix_symbolic(1,0,0,'exp');
elif mode == "auto":# automatic modelib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs']model.auto_symbolic(lib=lib)

结果如下:
在这里插入图片描述

  1. 最后输出数学表达式
model.train(dataset, opt="LBFGS", steps=50);
model.symbolic_formula()[0][0]

这里可能出现的问题是,会多余出一些小项,比如预测了正确的公式但是结尾部分会加上一个很小的数值,或者加上一个值很小的表达式。
结果如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/21855.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

百度/迅雷/夸克,网盘免费加速,已破!

哈喽,各位小伙伴们好,我是给大家带来各类黑科技与前沿资讯的小武。 之前给大家安利了百度网盘及迅雷的加速方法,详细方法及获取参考之前文章: 刚刚!度盘、某雷已破!速度50M/s! 本次主要介绍夸…

simulink基础学习笔记

写在前面 这个笔记是看B站UP 快乐的宇航boy 所出的simulink基础教程系列视频过程中记下来的,写的很粗糙不完整,也不会补。视频教程很细跟着做就行。 lesson1-7节的笔记up有,可以加up的群,里面大佬挺活跃的。 lesson8 for循环 For …

【C++初阶学习】第十二弹——stack和queue的介绍和使用

C语言栈:数据结构——栈(C语言版)-CSDN博客 C语言队列:数据结构——队列(C语言版)-CSDN博客 前言: 在之前学习C语言的时候,我们已经学习过栈与队列,并学习过如何使用C语言来实现栈与队列&…

OCR图片转Excel表格:没结构化的弊端

随着OCR技术的不断发展,将表格图片转为excel已不再是难题,但是,目前市面上的程序还大多处于仅能将图片表格转为普通的excel格式阶段,而不能将其结构化,这样就会产生许多的弊端,具体弊端如下: &l…

数据容器的通用操作、字符串大小比较 总结完毕!

1.数据容器的通用操作 1)五类数据容器是否都支持while循环/for循环 五类数据容器都支持for循环遍历 列表、元组、字符串都支持while循环,集合、字典不支持(无法下标索引) 尽管遍历的形式不同,但都支持遍历操作 2&a…

办公软件 Office 安装教程(亲测有效)

Office 现已更名为 Microsoft 365。习惯还是称作 Office。 1、Office 套装下载 Windows 的样子 这里下载的是最新版本的 O365ProPlus 安装完成后,点击关闭(请先不要打开)。 Mac 的样子 这里下载的是Office for Mac 2019(更多版…

速递FineWeb:一个拥有无限潜力的15T Tokens的开源数据集

大模型技术论文不断,每个月总会新增上千篇。本专栏精选论文重点解读,主题还是围绕着行业实践和工程量产。若在某个环节出现卡点,可以回到大模型必备腔调或者LLM背后的基础模型新阅读。而最新科技(Mamba,xLSTM,KAN)则提…

内核宕机自救

【问题】在测试内核级防篡改时,偶尔会遇到内核宕机的问题 【结论】进入紧急救援模式,将服务进程文件的start注释掉,即可 在Linux系统启动时,内核启动顺序选择界面,进入系统欢迎界面按上下左右键进入GRUB界面&#xff…

欧佩克+同意集体性减产延长,油价能否稳住?

KlipC报道:欧佩克组织同意将延长目前部分减产协议至2025年,以支撑油价。主要成员国把2023年11月宣布的日均220万桶的自愿减产措施延长至今年9月底,将在10月份根据市场情况开始缩减自愿减产规模。 高盛分析师表示,“我们认为这次欧…

python常见数据分析函数

apply DataFrame.apply(func, axis0, broadcastFalse, rawFalse, reduceNone, args(), **kwds) 第一个参数是函数 可以在Series或DataFrame上执行一个函数 支持对行、列或单个值进行处理 import numpy as np import pandas as pdf lambda x: x.max()-x.min()df pd.DataFrame(…

高端、大气、很牛B的免费wordpress模板主题

这是一款专为WordPress打造的极简主义风格主题,以白色和黑色为主色调,搭配红色点缀,营造出一种简洁、专业且具有视觉冲击力的效果。 该主题的设计理念是“简单即美”,旨在帮助用户快速搭建一个美观、易用的网站。它提供了丰富的自…

动态sql set标签 , trim标签

set标签 来看例子 set标案解决了逗号问题(当if条件不满足时,逗号无处安放的问题),我认为set标签可以识别这个问题,并自动忽略这个问题 <update id"update">update employee<set><if test"name!null">name#{name},</if><if te…

HTML基本元素包含HTML表单验证

可将以下代码复制另存为一个HTML文件浏览器打开自己去看看实际使用效果 <!DOCTYPE html> <html> <head> <meta charset"utf-8"><title>测试</title> </head> <body> <h1>很多事</h1> <h1><b&…

四、利用启发式算法进行特定数据集的残差网络结构搜索【框架+源码】

背景&#xff1a;工作之后干的事情跟算法关联甚少&#xff0c;整理下读书期间的负责和参与的work&#xff0c;再熟悉学习下。 边熟悉边整理喽~ CV Tradictional workCV AI based work机械臂视觉抓取项目机器学习全流程 Pipeline训练平台OCR生产线喷码识别三维重建(SfM)ROS机器人…

C++的vector使用优化

我们在上一章说了如何使用这个vector动态数组&#xff0c;这章我们说说如何更好的使用它以及它是如何工作的。当你创建一个vector&#xff0c;然后使用push_back添加元素&#xff0c;当当前的vector的内存不够时&#xff0c;会从内存中的旧位置复制到内存中的新位置&#xff0c…

Spring 之 Lifecycle 及 SmartLifecycle

最近在看Eureka源码&#xff0c;本想快速解决这场没有硝烟的战役&#xff0c;不曾想阻塞性问题一个接一个。为正确理解这个框架&#xff0c;我不得不耐着性子&#xff0c;慢慢梳理这些让人困惑的点。譬如本章要梳理的Lifecycle和SmartLifecycle。它们均为接口&#xff0c;其中后…

mysql的锁(全局锁)

文章目录 mysql按照锁的粒度分类全局锁概念&#xff1a;全局锁使用场景&#xff1a;全局锁备份案例&#xff1a; mysql按照锁的粒度分类 全局锁 概念&#xff1a; 全局锁就是对整个数据库实例加锁。MySQL 提供了一个加全局读锁的方法&#xff0c;命令是: Flush tables with…

排序算法——归并排序以及非递归实现

一、归并排序思想 归并排序&#xff08;MERGE-SORT&#xff09;是建立在归并操作上的一种有效的排序算法,该算法是采用分治法&#xff08;Divide andConquer&#xff09;的一个非常典型的应用。将已有序的子序列合并&#xff0c;得到完全有序的序列&#xff1b;即先使每个子序列…

重新ysyx

一、克隆仓库 1.创建ssh key ssh-keygen -t rsa cd ~/.ssh ls 查看里面是否有id_rsa id_rsa.pub ssh-keygen -t rsa -C "xiantong15834753336outlook.com" cat id_rsa.pub***********查看里面的内容&#xff0c;复制到下图中绿色的按钮 git init ssh -T g…

Marin说PCB之Max parallel知多少?

今天是个阳光明媚&#xff0c;万里乌云的好日子。小编我一如既往地到家打开电脑准备看腾讯视频的五十公里桃花坞的第四季&#xff0c;在看到汪苏泷汪台说650电台要解散的时候小编我差点也哭了。650电台之于桃花坞就像乐队的鼓手一样&#xff0c;都是一个团队的灵感啊&#xff0…