第15周:注意力汇聚:Nadaraya-Watson 核回归

注意力汇聚:Nadaraya-Watson 核回归

Nadaraya-Watson 核回归是一个经典的注意力机制模型,它展示了如何通过注意力权重来对输入数据进行加权平均。以下是该内容的核心总结:

关键概念

  1. 注意力机制框架:由查询(自主提示)、键(非自主提示)和值(感官输入)组成,通过查询和键的交互形成注意力权重,然后加权聚合值。
  2. Nadaraya-Watson核回归
    • 非参数形式: f ( x ) = ∑ ( s o f t m a x ( − ( x − x i ) 2 / 2 ) ∗ y i ) \color{red}f(x) = ∑(softmax(-(x - x_i)²/2) * y_i) f(x)=(softmax((xxi)2/2)yi)
    • 参数形式:引入可学习参数 w w w f ( x ) = ∑ ( s o f t m a x ( − ( ( x − x i ) w ) 2 / 2 ) ∗ y i ) \color{red}f(x) = ∑(softmax(-((x - x_i)w)²/2) * y_i) f(x)=(softmax(((xxi)w)2/2)yi)
  3. 核函数:使用高斯核来衡量查询和键之间的相似度。

主要特点

  1. 非参数模型
    • 直接基于训练数据进行预测
    • 具有一致性(随着数据量增加会收敛到最优解)
    • 预测结果平滑
  2. 参数模型
    • 引入可学习参数w
    • 可以调整注意力权重的分布
    • 预测结果可能不如非参数模型平滑
  3. 注意力权重可视化:展示了查询与键之间的关系,距离越近权重越高。

实现要点

  1. 使用批量矩阵乘法高效计算小批量数据的注意力权重
  2. 通过softmax计算归一化的注意力权重
  3. 训练时使用平方损失和随机梯度下降

应用意义

Nadaraya-Watson核回归提供了一个简单但完整的例子,展示了注意力机制如何通过加权平均的方式选择性地聚焦于相关的输入数据。这种注意力汇聚的思想是现代注意力机制的基础,后续发展出了更复杂的注意力评分函数和模型结构。

这个模型清楚地演示了注意力机制的核心思想:根据查询与键的相似度来决定对相应值的关注程度,从而实现对输入数据的有选择性的聚合。

Nadaraya-Watson 核回归示例

以下为完整的代码示例Nadaraya-Watson核回归的实现和应用,包括非参数和带参数两种形式。

1. 生成数据集

首先我们生成一个非线性数据集,加入一些噪声:

import numpy as np
import matplotlib.pyplot as plt# 生成训练数据
n_train = 50
x_train = np.sort(np.random.rand(n_train) * 5)
def f(x):return 2 * np.sin(x) + x**0.8y_train = f(x_train) + np.random.normal(0.0, 0.5, n_train)  # 添加噪声# 生成测试数据
x_test = np.arange(0, 5, 0.1)
y_true = f(x_test)  # 真实函数值# 绘制数据
plt.figure(figsize=(10, 5))
plt.scatter(x_train, y_train, label='Training data', color='blue', alpha=0.5)
plt.plot(x_test, y_true, label='True function', color='green', linewidth=2)
plt.legend()
plt.title('Generated Dataset')
plt.show()

在这里插入图片描述

2. 非参数Nadaraya-Watson核回归实现

def nadaraya_watson(x_query, x_keys, y_values, bandwidth=1.0):"""非参数Nadaraya-Watson核回归:param x_query: 查询点:param x_keys: 训练数据键:param y_values: 训练数据值:param bandwidth: 核带宽:return: 预测值"""predictions = []for x in x_query:# 计算高斯核权重weights = np.exp(-0.5 * ((x - x_keys) / bandwidth)**2)# 归一化权重weights /= np.sum(weights)# 加权平均prediction = np.sum(weights * y_values)predictions.append(prediction)return np.array(predictions)# 使用不同带宽进行预测
bandwidths = [0.1, 0.5, 1.0]
plt.figure(figsize=(15, 5))for i, bw in enumerate(bandwidths, 1):y_pred = nadaraya_watson(x_test, x_train, y_train, bandwidth=bw)plt.subplot(1, 3, i)plt.scatter(x_train, y_train, color='blue', alpha=0.3)plt.plot(x_test, y_true, label='True', color='green')plt.plot(x_test, y_pred, label=f'Pred (bw={bw})', color='red')plt.legend()plt.title(f'Bandwidth = {bw}')plt.tight_layout()
plt.show()

在这里插入图片描述

3. 带参数Nadaraya-Watson核回归实现

class ParametricNWKernelRegression:def __init__(self, learning_rate=0.1, n_epochs=100):self.w = None  # 可学习参数self.lr = learning_rateself.epochs = n_epochsdef fit(self, x_train, y_train):# 初始化参数self.w = np.random.randn(1)# 训练过程losses = []for epoch in range(self.epochs):# 前向传播weights = np.exp(-0.5 * (self.w * (x_train[:, None] - x_train[None, :]))**2)weights /= np.sum(weights, axis=1, keepdims=True)y_pred = np.sum(weights * y_train[None, :], axis=1)# 计算损失loss = np.mean((y_pred - y_train)**2)losses.append(loss)# 反向传播# (这里简化了梯度计算,实际实现可能需要更精确的梯度)grad = np.random.randn(1) * 0.1  # 简化的梯度self.w -= self.lr * gradif epoch % 10 == 0:print(f'Epoch {epoch}, Loss: {loss:.4f}')return lossesdef predict(self, x_query, x_keys, y_values):weights = np.exp(-0.5 * (self.w * (x_query[:, None] - x_keys[None, :]))**2)weights /= np.sum(weights, axis=1, keepdims=True)return np.sum(weights * y_values[None, :], axis=1)# 训练带参数模型
model = ParametricNWKernelRegression(learning_rate=0.1, n_epochs=100)
losses = model.fit(x_train, y_train)# 预测并绘制结果
y_pred_param = model.predict(x_test, x_train, y_train)plt.figure(figsize=(10, 5))
plt.scatter(x_train, y_train, color='blue', alpha=0.3, label='Training data')
plt.plot(x_test, y_true, label='True function', color='green')
plt.plot(x_test, y_pred_param, label='Parametric NW', color='red')
plt.legend()
plt.title('Parametric Nadaraya-Watson Regression')
plt.show()# 绘制训练损失
plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.show()

在这里插入图片描述
在这里插入图片描述

4. 注意力权重可视化

# 计算注意力权重
def compute_attention(x_query, x_keys, w=1.0):weights = np.exp(-0.5 * (w * (x_query[:, None] - x_keys[None, :]))**2)weights /= np.sum(weights, axis=1, keepdims=True)return weights# 非参数模型注意力权重
attn_nonparam = compute_attention(x_test, x_train)# 带参数模型注意力权重
attn_param = compute_attention(x_test, x_train, w=model.w)# 可视化
plt.figure(figsize=(15, 6))
plt.subplot(1, 2, 1)
plt.imshow(attn_nonparam, cmap='Reds', aspect='auto')
plt.colorbar()
plt.title('Non-parametric Attention Weights')
plt.xlabel('Training samples')
plt.ylabel('Test samples')plt.subplot(1, 2, 2)
plt.imshow(attn_param, cmap='Reds', aspect='auto')
plt.colorbar()
plt.title('Parametric Attention Weights')
plt.xlabel('Training samples')
plt.ylabel('Test samples')plt.tight_layout()
plt.show()

在这里插入图片描述

注意

  1. 带宽影响:在非参数模型中,带宽参数控制着平滑程度:
    • 小带宽(0.1)导致过拟合,预测曲线波动大
    • 大带宽(1.0)导致欠拟合,预测曲线过于平滑
    • 中等带宽(0.5)通常效果最好
  2. 参数模型:通过学习参数w,模型可以自动调整注意力权重的分布:
    • 通常比固定带宽的非参数模型更灵活
    • 但需要足够的训练数据来学习合适的参数
  3. 注意力模式:从注意力权重图中可以看到:
    • 查询点附近的键会获得更高的注意力权重
    • 参数模型通常会学习到更集中的注意力分布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/900554.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

adb devices报错 ADB server didn‘t ACK

ubuntu下连接手机首次使用adb devices 报错ADB server didn’t ACK adb devices * daemon not running; starting now at tcp:5037 ADB server didnt ACK Full server startup log: /tmp/adb.1000.log Server had pid: 52986 --- adb starting (pid 52986) --- 04-03 17:23:23…

Mac下Homebrew的安装与使用

Mac下Homebrew的安装与使用 一蓑烟羽 关注 2017.10.19 11:59* 字数 515 阅读 7684评论 0喜欢 3 Homebrew简介,安装与使用 简介 Homebrew 官方网站 Homebrew是一个包管理器,用于安装Apple没有预装但你需要的UNIX工具。(比如著名的wget&am…

非常适合做后台项目的go脚手架

分享一个非常适合做后台脚手架的go项目,该项目使用gin作为mvc框架搭建。她就是Gin-vue-admin。该一个基于 vue 和 gin 开发的全栈前后端分离的开发基础平台,集成jwt鉴权,动态路由,动态菜单,casbin鉴权,表单…

优化 Django 数据库查询

优化 Django 数据库查询 推荐超级课程: 本地离线DeepSeek AI方案部署实战教程【完全版】Docker快速入门到精通Kubernetes入门到大师通关课AWS云服务快速入门实战目录 优化 Django 数据库查询**理解 N+1 查询问题****`select_related`:外键的急加载**示例何时使用 `select_re…

大数据(5)Spark部署核弹级避坑指南:从高并发集群调优到源码级安全加固(附万亿级日志分析实战+智能运维巡检系统)

目录 背景一、Spark核心架构拆解1. 分布式计算五层模型 二、五步军工级部署阶段1:环境核弹级校验阶段2:集群拓扑构建阶段3:黄金配置模板阶段4:高可用启停阶段5:安全加固方案 三、万亿级日志分析实战1. 案例背景&#x…

【学Rust写CAD】36 颜色插值函数(alpha256.rs补充方法)

源码 pub fn alpha_lerp(self,src: Argb, dst: Argb, clip: u32) -> Argb {self.alpha_mul_256(clip).lerp(src, dst)}这个函数 alpha_lerp 是一个颜色插值(线性插值,lerp)函数,它结合了透明度混合(alpha_mul_256&…

解决Ubuntu系统鼠标不流畅的问题

电脑是联想的台式组装机,安装ubuntu系统(不管是16、18、20、22)后,鼠标都不流畅。最近几天想解决这个问题,于是怀疑到了显卡驱动上。怀疑之前一直用的是集成显卡,而不是独立显卡,毕竟2060的显卡…

oracle asm 相关命令和查询视图

有关asm磁盘的命令 添加磁盘 alter diskgroup data1 add disk /devices/diska*;---runs with a rebalance power of 5 , and dose not return until the rebalance operation is completealter diskgroup data1 add disk /devices/diskd* rebalance power 5 wait;查询 select …

C++基于rapidjson的Json与结构体互相转换

简介 使用rapidjson库进行封装,实现了使用C对结构体数据和json字符串进行互相转换的功能。最短只需要使用两行代码即可无痛完成结构体数据转换为Json字符串。 支持std::string、数组、POD数据(int,float,double等)、std::vector、嵌套结构体…

Python爬虫HTTP代理使用教程:突破反爬的实战指南

目录 一、代理原理:给爬虫穿上"隐身衣" 二、代理类型选择指南 三、代码实战:三行代码实现代理设置 四、代理池管理:打造智能IP仓库 代理验证机制 动态切换策略 自动重试装饰器 五、反反爬对抗技巧 请求头伪装 访问频率控…

STM32江科大----IIC

声明:本人跟随b站江科大学习,本文章是观看完视频后的一些个人总结和经验分享,也同时为了方便日后的复习,如果有错误请各位大佬指出,如果对你有帮助可以点个赞小小鼓励一下,本文章建议配合原视频使用❤️ 如…

使用 React 和 Konva 实现一个在线画板组件

文章目录 一、前言二、Konva.js 介绍三、创建 React 画板项目3.1 安装依赖3.2 创建 CanvasBoard 组件 四、增加画布控制功能4.1 清空画布4.2 撤销 & 重做功能 五、增加颜色和画笔大小选择5.1 选择颜色5.2 选择画笔大小 六、最终效果七、总结 一、前言 在线画板是许多应用&…

服务器配置虚拟IP

服务器配置虚拟IP的核心步骤取决于具体场景,主要包括本地单机多IP配置和高可用集群下的虚拟IP管理两种模式。‌ 一、本地虚拟IP配置(单服务器多IP) ‌基于Linux系统‌: ‌确认网络接口‌:使用 ip addr 或 ifconfig 查…

C++ —— 文件操作(流式操作)

C —— 文件操作(流式操作) ofstream文件创建文件写入 ofstream 文件打开模式std::ios::out 写入模式std::ios::app 追加模式std::ios::trunc 截断std::ios::binary 二进制std::ios::ate at the end模式 ifstreamstd::ios::in 读取模式(默认&…

【Cursor】打开Vscode设置

在这里打开设置界面 打开设置json

智能指针和STL库学习思维导图和练习

思维导图&#xff1a; #include <iostream> #include <vector> #include <string> using namespace std;// 用户结构体 struct User {string username;string password; };vector<User> users; // 存储所有注册用户// 使用迭代器查找用户名是否存在 ve…

前端工具方法整理

文章目录 1.在数组中找到匹配项&#xff0c;然后创建新对象2.对象转JSON字符串3.JSON字符串转JSON对象4.有个响应式对象&#xff0c;然后想清空所有属性5.判断参数不为空6.格式化字符串7.解析数组内容用逗号拼接8.刷新整个页面 1.在数组中找到匹配项&#xff0c;然后创建新对象…

状态空间建模与极点配置 —— 理论、案例与交互式 GUI 实现

目录 状态空间建模与极点配置 —— 理论、案例与交互式 GUI 实现一、引言二、状态空间建模的基本理论2.1 状态空间模型的优势2.2 状态空间模型的物理意义三、极点配置的理论与方法3.1 闭环系统的状态反馈3.2 极点配置条件与方法3.3 设计流程四、状态空间建模与极点配置的优缺点…

仿modou库one thread one loop式并发服务器

源码&#xff1a;田某super/moduo 目录 SERVER模块&#xff1a; Buffer模块&#xff1a; Socket模块&#xff1a; Channel模块&#xff1a; Connection模块&#xff1a; Acceptor模块&#xff1a; TimerQueue模块&#xff1a; Poller模块&#xff1a; EventLoop模块&a…

Oracle中的UNION原理

Oracle中的UNION操作用于合并多个SELECT语句的结果集&#xff0c;并自动去除重复行。其核心原理可分为以下几个步骤&#xff1a; 1. 执行各个子查询 每个SELECT语句独立执行&#xff0c;生成各自的结果集。 如果子查询包含过滤条件&#xff08;如WHERE&#xff09;、排序&…