最小二乘求解器lstsq,处理带权重和L2正则的线性回归

目录

代码注释版:

关键功能说明:

torch.linalg.cholesky 的原理

代码示例

Cholesky 分解的应用

与 torch.cholesky 的区别

总结


代码注释版:

from typing import Optionalimport torchdef lstsq(matrix: torch.Tensor, rhs: torch.Tensor, weights: torch.Tensor, l2_regularizer: Optional[torch.Tensor] = None,l2_regularizer_rhs: Optional[torch.Tensor] = None,shared: bool = False
) -> torch.Tensor:"""带权重和L2正则化的最小二乘求解器,使用Cholesky分解解决形如 (A^T W A + λI) x = A^T W b 的线性系统支持多任务共享参数(通过shared参数合并Gram矩阵和右侧项)Args:matrix: 设计矩阵A,形状为 [batch_size, n_obs, n_params]rhs: 右侧项b,形状为 [batch_size, n_obs, n_outputs]weights: 权重矩阵W的对角元素,形状为 [batch_size, n_obs]l2_regularizer: L2正则化项λ的对角矩阵,形状为 [batch_size, n_params, n_params]l2_regularizer_rhs: 正则化项对右侧的修正,形状为 [batch_size, n_params, n_outputs]shared: 是否共享参数(将多个系统的Gram矩阵和右侧项求和)Returns:最小二乘解,形状为 [batch_size, n_params, n_outputs]"""# 加权设计矩阵: W^(1/2) * Aweighted_matrix = weights.unsqueeze(-1) * matrix# 计算正则化的Gram矩阵: A^T W A + λIregularized_gramian = weighted_matrix.mT @ matrixif l2_regularizer is not None:regularized_gramian += l2_regularizer  # 添加L2正则项# 计算右侧项: A^T W b + λ_rhsATb = weighted_matrix.mT @ rhsif l2_regularizer_rhs is not None:ATb += l2_regularizer_rhs# 如果共享参数,合并所有batch的贡献if shared:regularized_gramian = regularized_gramian.sum(dim=0, keepdim=True)ATb = ATb.sum(dim=0, keepdim=True)# Cholesky分解求解chol = torch.linalg.cholesky(regularized_gramian)return torch.cholesky_solve(ATb, chol)def lstsq_partial_share(matrix: torch.Tensor,rhs: torch.Tensor,weights: torch.Tensor,l2_regularizer: torch.Tensor,n_shared: int = 0
) -> torch.Tensor:"""部分参数共享的最小二乘求解器将参数分为共享部分和独立部分:- 共享参数在所有样本间共享- 独立参数每个样本单独估计通过分块回归实现高效求解Args:matrix: 设计矩阵A,形状为 [batch_size, n_obs, n_params]rhs: 右侧项b,形状为 [batch_size, n_obs, n_outputs]weights: 权重矩阵的对角元素,形状为 [batch_size, n_obs]l2_regularizer: 正则化强度,形状为 [batch_size, n_params]n_shared: 共享参数的数量Returns:参数矩阵,前n_shared列为共享参数,其余为独立参数形状为 [batch_size, n_params, n_outputs]"""n_params = matrix.shape[-1]n_rhs_outputs = rhs.shape[-1]n_indep = n_params - n_shared# 全共享情况直接返回广播结果if n_indep == 0:result = lstsq(matrix, rhs, weights, l2_regularizer, shared=True)return result.expand(matrix.shape[0], -1, -1)# 将正则化项转换为设计矩阵的扩展部分# 相当于添加 λI 的正则化项matrix = torch.cat([matrix, batch_eye(n_params, matrix.shape[0])], dim=1)rhs = torch.nn.functional.pad(rhs, (0, 0, 0, n_params))  # 右侧添加0weights = torch.cat([weights, l2_regularizer.unsqueeze(0).expand(matrix.shape[0], -1)], dim=1)# 分割共享和独立参数对应的设计矩阵matrix_shared, matrix_indep = torch.split(matrix, [n_shared, n_indep], dim=-1)# 步骤1:求解独立参数对共享参数和输出的影响indep_coeffs = lstsq(matrix_indep, torch.cat([matrix_shared, rhs], dim=-1), weights)coeff_indep2shared, coeff_indep2rhs = torch.split(indep_coeffs, [n_shared, n_rhs_outputs], dim=-1)# 步骤2:用残差求解共享参数shared_residual = matrix_shared - matrix_indep @ coeff_indep2sharedrhs_residual = rhs - matrix_indep @ coeff_indep2rhscoeff_shared2rhs = lstsq(shared_residual, rhs_residual, weights, shared=True)# 步骤3:更新独立参数系数coeff_indep2rhs = coeff_indep2rhs - coeff_indep2shared @ coeff_shared2rhs# 合并结果:共享参数广播,独立参数保持独立coeff_shared2rhs = coeff_shared2rhs.expand(matrix.shape[0], -1, -1)return torch.cat([coeff_shared2rhs, coeff_indep2rhs], dim=1)def batch_eye(n_params: int, batch_size: int) -> torch.Tensor:"""生成批次对角矩阵Args:n_params: 矩阵维度batch_size: 批次大小Returns:形状为 [batch_size, n_params, n_params] 的单位矩阵批次"""return torch.eye(n_params).reshape(1, n_params, n_params).expand(batch_size, -1, -1)

关键功能说明:

  1. lstsq:

    • 核心最小二乘求解器,处理带权重和L2正则的线性回归

    • 使用Cholesky分解提高数值稳定性

    • 支持多任务参数共享模式(shared=True时合并所有任务的贡献)

  2. lstsq_partial_share:

    • 处理部分参数共享的回归问题

    • 通过三步分块回归实现:

      1. 估计独立参数对共享参数和输出的影响

      2. 用残差估计共享参数

      3. 修正独立参数估计值

    • 通过矩阵拼接技巧将正则化转换为设计矩阵扩展

  3. batch_eye:

    • 生成批次单位矩阵,用于构建正则化项

    • 典型应用:将L2正则转换为扩展设计矩阵的伪观测

torch.linalg.cholesky 的原理

torch.linalg.cholesky(A) 用于对对称正定矩阵 AAA 进行 Cholesky 分解,即将其分解为:

A=LLTA = L L^TA=LLT

其中:

  • AAA 是 对称正定矩阵(必须满足 A=ATA = A^TA=AT 且所有特征值大于 0)。

  • LLL 是 下三角矩阵

计算 Cholesky 分解 的方式基于逐行计算 LLL:

  1. 计算对角元素:

    Lii=Aii−∑k=1i−1Lik2L_{ii} = \sqrt{ A_{ii} - \sum_{k=1}^{i-1} L_{ik}^2 }Lii​=Aii​−k=1∑i−1​Lik2​​
  2. 计算非对角元素:

    Lji=1Lii(Aji−∑k=1i−1LjkLik),j>iL_{ji} = \frac{1}{L_{ii}} \left( A_{ji} - \sum_{k=1}^{i-1} L_{jk} L_{ik} \right), \quad j > iLji​=Lii​1​(Aji​−k=1∑i−1​Ljk​Lik​),j>i

这个算法 只需要计算下三角部分,所以比 LU 分解 计算量更少,适用于 正定矩阵的快速求解


代码示例

import torch# 生成一个对称正定矩阵
A = torch.tensor([[4.0, 12.0, -16.0], [12.0, 37.0, -43.0], [-16.0, -43.0, 98.0]])# Cholesky 分解
L = torch.linalg.cholesky(A)
print(L)

输出

tensor([[ 2.0000, 0.0000, 0.0000],

[ 6.0000, 1.0000, 0.0000],

[-8.0000, 5.0000, 3.0000]])

可以验证:

print(torch.mm(L, L.T))# 结果应当等于 A

Cholesky 分解的应用

  1. 解线性方程组 Ax=bAx = bAx=b:

    • 先求 L = torch.linalg.cholesky(A)

    • Ly = b(前代法)

    • L^T x = y(后代法)

  2. 生成多元正态分布

    • 如果协方差矩阵 Σ\SigmaΣ 进行 Cholesky 分解 Σ=LLT\Sigma = L L^TΣ=LLT,

    • 则可以用 L @ torch.randn(n, d) 生成符合协方差 Σ\SigmaΣ 的多元正态分布数据。


torch.cholesky 的区别

  • torch.cholesky(A) 旧版 API,不推荐使用。

  • torch.linalg.cholesky(A) 现代 API,支持 batch 计算,推荐使用。


总结

  • torch.linalg.cholesky(A) 计算 对称正定矩阵Cholesky 分解,分解成下三角矩阵 L,使得 A=LLTA = L L^TA=LLT。

  • 计算方式比 LU 分解更快,主要用于 正定矩阵的求解、统计学、多元正态分布 等。

  • 使用 Cholesky 分解求解线性方程组比直接求逆更稳定高效。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/76441.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AI辅助下基于ArcGIS Pro的SWAT模型全流程高效建模实践与深度进阶应用

目前,流域水资源和水生态问题逐渐成为制约社会经济和环境可持续发展的重要因素。SWAT模型是一种基于物理机制的分布式流域水文与生态模拟模型,能够对流域的水循环过程、污染物迁移等过程进行精细模拟和量化分析。SWAT模型目前广泛应用于流域水文过程研究…

DHT11数字温湿度传感器驱动开发全解析(下) | 零基础入门STM32第八十八步

主题内容教学目的/扩展视频DHT11芯片电路连接,手册分析。驱动程序,读出数据。能读出温湿度值即可。 师从洋桃电子,杜洋老师 📑文章目录 一、硬件接口与通信原理1.1 硬件连接拓扑1.2 单总线通信时序 二、驱动代码深度解析&#xff…

24、网络编程基础概念

网络编程基础概念 网络结构模式MAC地址IP地址子网掩码端口网络模型协议网络通信的过程(封装与解封装) 网络结构模式 C/S结构,由客户机和服务器两部分组成,如QQ、英雄联盟 B/S结构,通过浏览器与服务器进程交互&#xf…

【超详细】讲解Ubuntu上如何配置分区方案

Ubuntu 的分区方案 一、通用分区方案(200G为例) EFI系统分区(仅UEFI启动模式需要,) 大小:512MB–1GB类型:主分区(FAT32格式)挂载点:/boot/efi说明&#xff1…

函数的局部变量和全局变量的区分,Kimi的回答

这段代码的目的是通过计算 2**i 和 5**i 的首位数字,并将这两个首位数字的乘积添加到一个集合中,最终返回这些乘积的总和。下面是具体的解释和问题的分析。 sum_t的角色: sum_t 是一个累加器,用来存储所有独特的(不重复…

RNN模型及NLP应用(5/9)——多层RNN、双向RNN、预训练

声明: 本文基于哔站博主【Shusenwang】的视频课程【RNN模型及NLP应用】,结合自身的理解所作,旨在帮助大家了解学习NLP自然语言处理基础知识。配合着视频课程学习效果更佳。 材料来源:【Shusenwang】的视频课程【RNN模型及NLP应用…

【3.软件工程】3.4 原型及相关模型

软件开发模型进化论:从原型驱动到混合模型的完整指南 🔄 一、模型进化关系全景图 #mermaid-svg-GcOFjt54gUs4oPeu {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-GcOFjt54gUs4oPeu .error-i…

硬件与软件的边界-从单片机到linux的问答详解

硬件与软件的边界——从单片机到 Linux 设备驱动的问答详解 在嵌入式开发和操作系统领域,经常会有人问: “如果一个设备里没有任何代码,硬件是不是依然会工作?例如,数据收发、寄存器数据存储、甚至中断触发&#xff…

玛卡巴卡的k8s知识点问答题(七)

25. 说明 Job 与 CronJob 的功能 Job 功能: 用于运行一次性任务(批处理任务),确保一个或多个 Pod 成功完成任务后退出。 适用于数据处理、备份、测试等场景,任务完成后 Pod 不会自动重启。 特点: 任务…

【NLP 51、一些LLM模型结构上的变化】

目录 一、multi-head 共享 二、attention结构 1.传统的Tranformer结构 2.GPTJ —— 平行放置的Transformer结构 三、归一化层位置的选择 1.Post LN: 2.Pre-LN【目前主流】: 3.Sandwich-LN: 四、归一化函数选择 1.传统的归一化函数 LayerNorm …

VS+Qt配置QtXlsx库实现execl文件导入导出(全教程)

一、配置QtXlsx 1.1 下载解压QtXlsxWriter(在github下载即可) 网址:https://github.com/dbzhang800/QtXlsxWriter 1.2 使用qt运行 点击qtxlsx.pro运行QtXlsxWriter 选择DesktopQt51211MSVC201564bit编译器(选择自己本地电脑qt…

Golang的文件处理优化策略

Golang的文件处理优化策略 一、Golang的文件处理优化策略概述 是一门效率高、易于编程的编程语言,它的文件处理能力也非常强大。 在实际开发中,需要注意一些优化策略,以提高文件处理的效率和性能。 本文将介绍Golang中的文件处理优化策略&…

自学-C语言-基础-数组、函数、指针、结构体和共同体、文件

这里写自定义目录标题 代码环境:?问题思考:一、数组二、函数三、指针四、结构体和共同体五、文件问题答案: 代码环境: Dev C ?问题思考: 把上门的字母与下面相同的字母相连,线不能…

VMware+Ubuntu+VScode+ROS一站式教学+常见问题解决

目录 一.VMware的安装 二.Ubuntu下载 1.前言 2.Ubuntu版本选择 三.VMware中Ubuntu的安装 四.Ubuntu系统基本设置 1.中文更改 2.中文输入法更改 3. 辅助工具 vmware tools 五.VScode的安装ros基本插件 1.安装 2.ros辅助插件下载 六.ROS安装 1.安装ros 2.配置ROS…

PostgreSQL pg_repack 重新组织表并释放表空间

pg_repack pg_repack是 PostgreSQL 的一个扩展,它允许您从表和索引中删除膨胀,并可选择恢复聚集索引的物理顺序。与CLUSTER和VACUUM FULL不同,它可以在线工作,在处理过程中无需对已处理的表保持独占锁定。pg_repack 启动效率高&a…

5G_WiFi_CE_射频输出功率、发射功率控制(TPC)和功率密度测试

目录 一、规范要求 1、法规目录: (1)RF Output Power (2)Transmit Power Control (TPC) (3)Power Density 2、限值: 二、EIRP测试方法 (1)测试条件 (2&#xff…

扫描线离散化线段树解决矩形面积并-洛谷P5490

https://www.luogu.com.cn/problem/P5490 题目描述 求 n n n 个四边平行于坐标轴的矩形的面积并。 输入格式 第一行一个正整数 n n n。 接下来 n n n 行每行四个非负整数 x 1 , y 1 , x 2 , y 2 x_1, y_1, x_2, y_2 x1​,y1​,x2​,y2​,表示一个矩形的四个…

Java项目之基于ssm的简易版营业厅宽带系统(源码+文档)

项目简介 简易版营业厅宽带系统实现了以下功能: 此营业厅宽带系统利用当下成熟完善的SSM框架,使用跨平台的可开发大型商业网站的Java语言,以及最受欢迎的RDBMS应用软件之一的Mysql数据库进行程序开发。实现了营业厅宽带系统基础数据的管理&…

从入门到入土,SQLServer 2022慢查询问题总结

列为,由于公司原因,作者接触了一个SQLServer 2022作为数据存储到项目,可能是上一任的哥们儿离开的时候带有情绪,所以现在项目的主要问题就是,所有功能都实现了,但是就是慢,列表页3s打底,客户很生气,经过几周摸爬滚打,作以下总结,作为自己的成长记录。 一、索引问题…

PDF处理控件Aspose.PDF教程:在Python、Java 和 C# 中旋转 PDF 文档

您是否希望快速轻松地在线旋转PDF文档?无论您需要修复文档的方向还是只想重新排列页面,本指南都能满足您的需求。有简单的方法可以解决此问题 - 无论您喜欢在线工具还是编程解决方案。 在本指南中,我们将向您展示如何免费在线旋转 PDF&#…