多任务高斯过程数学原理和Pytorch实现示例

高斯过程其在回归任务中的应用我们都很熟悉了,但是我们一般介绍的都是针对单个任务的,也就是单个输出。本文我们将讨论扩展到多任务gp,强调它们的好处和实际实现。

本文将介绍如何通过共区域化的内在模型(ICM)和共区域化的线性模型(LMC),使用高斯过程对多个相关输出进行建模。

多任务高斯过程

高斯过程是回归和分类任务中的一个强大工具,提供了一种非参数方式来定义函数的分布。当处理多个相关输出时,多任务高斯过程可以模拟这些任务之间的依赖关系,从而带来更好的泛化和预测效果。

数学上,高斯过程被定义为一组随机变量,其中任何有限数量的变量都具有联合高斯分布。对于一组输入点 X,相应的输出值 f(X) 是联合高斯分布的:

其中 m(X) 是均值函数,k(X, X) 是协方差矩阵。

在多任务环境中,目标是建模函数 f: X → R^T,这样我们就有 T 个输出或任务,f_t(X) 对于 t = 1, …, T。这意味着均值函数是 m: X → R^T,核函数是 k: X × X → R^{T × T}。

我们如何模拟这些任务之间的相关性?

独立多任务高斯过程

一个简单的独立多输出高斯过程模型将每个任务独立建模,不考虑任务之间的任何相关性。在这种情况下,每个任务都有自己的高斯过程,具有自己的均值和协方差函数。数学上可以表达为:

这使得协方差矩阵 k(x, x) 是块对角形的,即 diag(k_1(x, x), …, k_T(x, x))。

这种方法没有利用任务之间的共享信息,可能导致性能不佳,尤其是当某些任务的数据有限时。

Intrinsic model of coregionalization(ICM)

ICM(共区域化的内在模型)方法通过引入核心区域化矩阵 (B) 来推广独立多输出高斯过程,该矩阵模型化任务之间的相关性。ICM方法中的协方差函数定义如下:

其中 k_input是在输入空间上定义的协方差函数(例如,平方指数核),而 B ∈ R^{T × T} 是捕捉任务特定协方差的核心区域化矩阵。矩阵 (B) 通常参数化为 (B = W W^T),其中W ∈ R^{T × r*}* ,且 ® 是核心区域化矩阵的秩。这确保了核函数是半正定的。

ICM方法可以学习任务之间的共享结构。任务之间的皮尔逊相关系数可以表示为:

Linear model of coregionalization (LMC)

另一种常见的方法是LMC(线性核心区域化模型)模型,它通过允许更多种类的输入核来扩展ICM。在LMC模型中,协方差函数定义为:

其中 (Q) 是基核的数量,k_input^q 是基核,而 (B_q) 是每个基核的核心区域化矩阵。通过结合多个基核,这个模型可以捕捉任务之间更复杂的相关性。

我们可以通过设置 (Q=1) 来恢复ICM模型。或者说ICM是Q=1的LMC模型

噪声建模

在多任务高斯过程中,我们需要考虑一个多输出似然函数,该函数为每个任务模型化噪声。

标准的似然函数通常是多维高斯似然,可以表示为:

其中 (y) 是观测输出,(f(x)) 是潜在函数,Sigma是噪声协方差矩阵。

这里的灵活性在于噪声协方差矩阵的选择,它可以是对角线Σ=diag(σ12,…,σT2)(每个任务独立噪声)或完整(任务间相关噪声)。

后者通常表示为 Σ=LLT,其中 L∈RT×r,且 r是噪声协方差矩阵的秩。这允许捕捉不同任务的噪声项之间的相关性。

最终包含噪声的协方巧矩阵则由以下给出:

其中 (K) 是没有噪声的协方差矩阵,(I) 是单位矩阵,⊗表示克罗内克乘积,以便将噪声项添加到协方差矩阵的对角块中。

PyTorch实现

我们上面介绍了多任务的高斯过程的数学原理,下面开使用’ GPyTorch '与ICM内核实现多任务GP的示例。

首先需要安装所需的软件包,包括’ Torch ', ’ GPyTorch ', ’ matplotlib ‘和’ seaborn ', ’ numpy '。

 %pip install torch gpytorch matplotlib seaborn numpy pandasimport torchimport gpytorchfrom matplotlib import pyplot as pltimport numpy as npimport seaborn as snsimport pandas as pd

然后,定义多任务GP模型。使用ICM内核(秩r=1)来捕获任务之间的相关性。

我们为两个任务(正弦和移位正弦)生成合成的噪声训练数据,以便有相关的输出。

噪声协方差矩阵为

其中σ²1 = σ²2 = 0.1²,ρ = 0.3。

最后,我们通过绘制每个任务的平均预测和置信区间来训练模型并评估其性能。

 # Define the kernel with coregionalizationclass MultitaskGPModel(gpytorch.models.ExactGP):def __init__(self, train_x, train_y, likelihood, num_tasks):super(MultitaskGPModel, self).__init__(train_x, train_y, likelihood)self.mean_module = gpytorch.means.MultitaskMean(gpytorch.means.ConstantMean(), num_tasks=num_tasks)self.covar_module = gpytorch.kernels.MultitaskKernel(gpytorch.kernels.RBFKernel(), num_tasks=num_tasks, rank=1)def forward(self, x):mean_x = self.mean_module(x)covar_x = self.covar_module(x)return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)# Training dataf1 = lambda x:  torch.sin(x * (2 * torch.pi))f2 = lambda x: torch.sin((x - 0.1) * (2 * torch.pi))train_x = torch.linspace(0, 1, 10)train_y = torch.stack([f1(train_x),f2(train_x)]).T# Define the noise covariance matrix with correlation = 0.3sigma2 = 0.1**2Sigma = torch.tensor([[sigma2, 0.3 * sigma2], [0.3 * sigma2, sigma2]])# Add noise to the training datatrain_y += torch.tensor(np.random.multivariate_normal(mean=[0,0], cov=Sigma, size=len(train_x)))# Model and likelihoodnum_tasks = 2likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=num_tasks, rank=1)model = MultitaskGPModel(train_x, train_y, likelihood, num_tasks)# Training the modelmodel.train()likelihood.train()optimizer = torch.optim.Adam(model.parameters(), lr=0.1)mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)num_iter = 500for i in range(num_iter):optimizer.zero_grad()output = model(train_x)loss = -mll(output, train_y)loss.backward()optimizer.step()scheduler.step()# Evaluationmodel.eval()likelihood.eval()test_x = torch.linspace(0, 1, 100)with torch.no_grad(), gpytorch.settings.fast_pred_var():pred_multi = likelihood(model(test_x))# Plot predictionsfig, ax = plt.subplots()colors = ['blue', 'red']for i in range(num_tasks):ax.plot(test_x, pred_multi.mean[:, i], label=f'Mean prediction (Task {i+1})', color=colors[i])ax.plot(test_x, [f1(test_x), f2(test_x)][i], linestyle='--', label=f'True function (Task {i+1})')lower = pred_multi.confidence_region()[0][:, i].detach().numpy()upper = pred_multi.confidence_region()[1][:, i].detach().numpy()ax.fill_between(test_x,lower,upper,alpha=0.2,label=f'Confidence interval (Task {i+1})',color=colors[i])ax.scatter(train_x, train_y[:, 0], color='black', label=f'Training data (Task 1)')ax.scatter(train_x, train_y[:, 1], color='gray', label=f'Training data (Task 2)')ax.set_title('Multitask GP with ICM')ax.legend(loc='lower center', bbox_to_anchor=(0.5, -0.2),ncol=3, fancybox=True)

在使用

GPyTorch

时,通过使用

MultitaskMean

MultitaskKernel

MultitaskGaussianLikelihood

类,ICM模型的实现非常简单。这些类可以处理多任务结构、噪声和核心区域化矩阵,允许我们专注于模型定义和训练。

训练的循环也与标准高斯过程类似,以负边际对数似然作为损失函数,并使用优化器来更新模型参数。

在训练过程中添加了一个调度器来降低学习率,这可以帮助稳定优化过程。

 W = model.covar_module.task_covar_module.covar_factorB = W @ W.Tfig, ax = plt.subplots()sns.heatmap(B.detach().numpy(), annot=True, ax=ax, cbar=False, square=True)ax.set_xticklabels(['Task 1', 'Task 2'])ax.set_yticklabels(['Task 1', 'Task 2'])ax.set_title('Coregionalization matrix B')fig.show()L = model.likelihood.task_noise_covar_factor.detach().numpy()Sigma = L @ L.Tfig, ax = plt.subplots()sns.heatmap(Sigma, annot=True, ax=ax, cbar=False, square=True)ax.set_xticklabels(['Task 1', 'Task 2'])ax.set_yticklabels(['Task 1', 'Task 2'])ax.set_title('Noise covariance matrix')fig.show()

下面的图展示了模型学习的核心区域化矩阵 B 以及噪声协方差矩阵 Σ。

这张图捕捉了任务之间的相关性。如我们所见,B 的非对角线元素是正的。

这个图代表每个任务的噪声水平。注意到模型已经正确学习了噪声相关性。

比较

为了突出使用ICM方法对相关输出建模的优势,我们可以将其与独立处理每个任务、忽略任务之间任何潜在相关性的模型进行比较。

为每个任务定义一个单独的GP,训练它们,并在测试数据上评估它们的性能。

 class IndependentGPModel(gpytorch.models.ExactGP):def __init__(self, train_x, train_y, likelihood):super(IndependentGPModel, self).__init__(train_x, train_y, likelihood)self.mean_module = gpytorch.means.ConstantMean()self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())def forward(self, x):mean_x = self.mean_module(x)covar_x = self.covar_module(x)return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)# Create models and likelihoods for each tasklikelihoods = [gpytorch.likelihoods.GaussianLikelihood() for _ in range(num_tasks)]models = [IndependentGPModel(train_x, train_y[:, i], likelihoods[i]) for i in range(num_tasks)]# Training the independent modelsfor i, (model, likelihood) in enumerate(zip(models, likelihoods)):model.train()likelihood.train()optimizer = torch.optim.Adam(model.parameters(), lr=0.1)mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)for _ in range(num_iter):optimizer.zero_grad()output = model(train_x)loss = -mll(output, train_y[:, i])loss.backward()optimizer.step()scheduler.step()# Evaluationfor model, likelihood in zip(models, likelihoods):model.eval()likelihood.eval()with torch.no_grad(), gpytorch.settings.fast_pred_var():pred_inde = [likelihood(model(test_x)) for model, likelihood in zip(models, likelihoods)]# Plot predictionsfig, ax = plt.subplots()for i in range(num_tasks):ax.plot(test_x, pred_inde[i].mean, label=f'Mean prediction (Task {i+1})', color=colors[i])ax.plot(test_x, [f1(test_x), f2(test_x)][i], linestyle='--', label=f'True function (Task {i+1})')lower = pred_inde[i].confidence_region()[0]upper = pred_inde[i].confidence_region()[1]ax.fill_between(test_x,lower,upper,alpha=0.2,label=f'Confidence interval (Task {i+1})',color=colors[i])ax.scatter(train_x, train_y[:, 0], color='black', label='Training data (Task 1)')ax.scatter(train_x, train_y[:, 1], color='gray', label='Training data (Task 2)')ax.set_title('Independent GPs')ax.legend(loc='lower center', bbox_to_anchor=(0.5, -0.2),ncol=3, fancybox=True)

在性能方面,比较多任务GP与ICM和独立GP在测试数据上预测的均方误差(MSE)。

 mean_multi = pred_multi.mean.numpy()mean_inde = np.stack([pred.mean.numpy() for pred in pred_inde]).Ttest_y = torch.stack([f1(test_x), f2(test_x)]).T.numpy()MSE_multi = np.mean((mean_multi - test_y) ** 2)MSE_inde = np.mean((mean_inde - test_y) ** 2)df = pd.DataFrame({'Model': ['ICM', 'Independent'],'MSE': [MSE_multi, MSE_inde]})df

可以看到由于共区域化矩阵学习到的共享结构,ICM在MSE方面略优于独立gp。当处理更复杂的任务或有限的数据时,这种改进可能更为显著。

在独立GP的场景中,每个GP从一个较小的10个点的数据集中学习,这可能会导致过拟合或次优泛化。具有ICM的多任务GP使用所有20个点来学习指数核参数的平方。这种共享信息有助于改进对这两个任务的预测。

https://avoid.overfit.cn/post/f804e93bd5dd4c4ab9ede5bf1bc9b6c8

作者:Andrea Ruglioni

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/48589.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Linux知识点汇总】07 Linux系统防火墙相关命令,关闭和开启防火墙、开放端口号

​完整系列文章目录 【Linux知识点汇总】 心血来潮突然想起之前写过的系列文章【Linux知识点汇总】还未完结,那么今天就继续吧 说明:这个系列的内容,在系列【Linux服务器Java环境搭建】中会经常用到,大家可以自行查找相关命令 一、…

Docker搭建本地私有仓库

目录 1.下载运行registry 镜像 2.添加私有镜像仓库地址 3.为镜像添加标签 4.上传到私有仓库 5.查看私有仓库的所有镜像 6.测试私有仓库下载 1.下载运行registry 镜像 docker pull registry docker run -d -v /data/registry:/var/lib/registry -p 5000:5000 --restartal…

PostgreSQL使用(二)——插入、更新、删除数据

说明:本文介绍PostgreSQL的DML语言; 插入数据 -- 1.全字段插入,字段名可以省略 insert into tb_student values (1, 张三, 1990-01-01, 88.88);-- 2.部分字段插入,字段名必须写全 insert into tb_student (id, name) values (2,…

vue3【详解】跨组件通信 -- 依赖注入 provide inject

用于解决跨组件&#xff08;父组件与所有后代&#xff09;数据通信 提供数据 provide 传出数据的组件 &#xff08;通常为父辈组件&#xff09;提供数据 <script setup> import { provide } from vueprovide(/* 注入名 */ message, /* 值 */ hello!) </script>pro…

pycharm中运行.sh文件

最近在跑一个项目代码&#xff0c;里面要运行.sh文件。于是配置了下如何在pycharm中正常运行.sh文件。 首先安装好git&#xff0c;然后 File—>Settings—>Tools—>Terminal—>Shell path&#xff0c;将cmd.exe改成刚刚下载的git的路径&#xff0c;注意选择的是s…

web服务器——虚拟主机配置实战

搭建静态网站 —— 基于 http 协议的静态网站 实验 1 &#xff1a;搭建一个 web 服务器&#xff0c;访问该服务器时显示 “hello world” 欢迎界面 。 实验 2 &#xff1a;建立两个基于 ip 地址访问的网站&#xff0c;要求如下 该网站 ip 地址的主机位为 100 &#xff0c;设置…

web——搭建静态网站——基于http协议的静态网站

实验 4 &#xff1a;建立两个基于域名访问的网站&#xff0c;要求如下&#xff1a; 新建一个网站&#xff0c;域名为 www.ceshi.com &#xff0c;设置网站首页目录为 /www/name &#xff0c;网页内容为 this is test 。 新建一个网站&#xff0c;域名为 rhce.first.day &…

阵列信号处理学习笔记(一)--阵列信号处理定义

阵列信号 阵列信号处理学习笔记&#xff08;一&#xff09;–阵列信号处理定义 阵列信号处理学习笔记&#xff08;二&#xff09;–空域滤波基本原理 文章目录 阵列信号前言一、阵列信号处理定义1.1 信号1.2 阵列 二、雷达数据中哪些属于空间采样总结 前言 MOOC 阵列信号处理…

在线 PDF 制作者泄露用户上传的文档

两家在线 PDF 制作者泄露了数万份用户文档&#xff0c;包括护照、驾驶执照、证书以及用户上传的其他个人信息。 我们都经历过这样的情况&#xff1a;非常匆忙&#xff0c;努力快速制作 PDF 并提交表单。许多人向在线 PDF 制作者寻求帮助&#xff0c;许多人的祈祷得到了回应。 …

FOG Project 文件名命令注入漏洞复现(CVE-2024-39914)

0x01 产品简介 FOG是一个开源的计算机镜像解决方案,旨在帮助管理员轻松地部署、维护和克隆大量计算机。FOG Project 提供了一套功能强大的工具,使用户能够快速部署操作系统、软件和配置设置到多台计算机上,从而节省时间和精力。该项目支持基于网络的 PXE 启动、镜像创建和还…

【人工智能】AI音乐创作兴起与AI伦理的新视角

文章目录 &#x1f34a;AI音乐创作&#xff1a;一键生成&#xff0c;打造你的专属乐章&#x1f34a;1 市面上的AI音乐应用1.1 Suno AI1.2 网易天音 &#x1f34a;2 AI音乐创作的流程原理(直接制作可跳到第3点)2.1 AI音乐流派2.2 AI音乐风格2.3 AI音乐的结构顺序2.5 选择AI音乐乐…

手机如何播放电脑的声音?

准备工具&#xff1a; 有线耳机&#xff0c;手机&#xff0c;电脑&#xff0c;远控软件 1.有线耳机插电脑上 2.电脑安装pc版远控软件&#xff0c;手机安装手机端控制版远控软件 3.手机控制电脑开启声音控制 用手机控制电脑后&#xff0c;打开声音控制&#xff0c;电脑播放视频…

uniapp vue3 上传视频组件封装

首先创建一个 components 文件在里面进行组件的创建 下面是 vvideo组件的封装 也就是图片上传组件 只是我的命名是随便起的 <template><!-- 上传视频 --><view class"up-page"><!--视频--><view class"show-box" v-for"…

Netty HTTPS服务端高并发宕机案例

读李林峰《netty进阶指南》于第18章有感。特此记录一下问题的现象&#xff0c;以及他是如何排障的&#xff0c;以此加深理解 目录标题 事件梳理排查事后分析如何解决总结 事件梳理 某系统内部两个模块之间采用 HTTPS 通信。 某天&#xff1a; 客户端某时间吞吐量为0&#xf…

海思arm-hisiv400-linux-gcc 交叉编译rsyslog 记录心得

需要编译rsyslog,参考海思3536平台上rsyslog交叉编译、使用-CSDN博客和rsyslog移植&#xff08;亲测成功&#xff09;_rsyslog交叉编译-CSDN博客 首先下载了要用到的一些库的源码&#xff0c;先交叉编译这些库 原来是在centos6上交叉编译的&#xff0c;结果编译时报缺少软件要…

MySQL练习02

题目 步骤 创建数据库 create database mydb8_worker; #创建数据库 use mydb8_worker; #使用数据库 创建表 create table t_worker( department_id int(11) not null comment 部门号, worker_id int(11) primary key not null comment 职工号, worker_date date not …

数据结构 - 栈(精简介绍)

文章目录 普通栈Stack用法Q 最长有效括号 单调栈Q 接雨水 普通栈 栈就是一个先进后出的结构 想象一个容器&#xff0c;往里面一层一层放东西&#xff0c;最早放进去的东西被压在下面&#xff08;所以放元素也叫压栈&#xff09;&#xff0c;要拿到这个最低层的东西需要先把上面…

【系统架构设计师】十三、软件可靠性(基本概念|软件可靠性建模)

目录 一、基本概念 1.1 定义 1.2 软件可靠性的定量描述 1.3 可靠性测试的意义 1.4 广义的软件可靠性测试和狭义的软件可靠性测试 二、软件可靠性建模 2.1 可靠性模型的组成 2.2 可靠性模型的共同假设 2.3 可靠性模型的重要特性 2.4 可靠性建模方法 往期推荐 历年真…

当Excel处理神器EasyExcel遇上Apache POI:一场关于依赖的趣味‘撞车’冒险

目录 前言 报错 解决思想 解决方案 结尾 前言 &#x1f388;&#x1f388;&#x1f388;"Hey there, 大家好&#xff01;我是Blue&#xff0c;今天可不是一般的‘代码奇遇记’&#xff01;我在与EasyExcel这位数据处理界的魔术师共舞时&#xff0c;突然遭遇了前所未…

Unity-URP-SSAO记录

勾选After Opacity Unity-URP管线&#xff0c;本来又一个“bug”, 网上查不到很多关于ssao的资料 以为会不会又是一个极度少人用的东西 而且几乎都是要第三方替代 也完全没有SSAO大概的消耗是多少&#xff0c;完全是黑盒(因为用的人少&#xff0c;研究的人少&#xff0c;优…