em算法 实例 正态分布_EM算法解GMM

看了很多介绍EM算法的文章,但是他们都没有代码,所以在这里写出来。

Jensen 不等式

参考期望最大算法

Jensen不等式在优化理论中大量用到,首先来回顾下凸函数和凹函数的定义。假设

是定义域为实数的函数,如果对于所有的
的二阶导数大于等于0,那么
是凸函数。当
是向量时,如果hessian矩阵
是半正定(即
是凸函数)。如果,
的二阶导数小于0或者
就是凹函数。

Jensen不等式描述如下:

  1. 如果
    是凸函数,
    是随机变量,则
    是严格凸函数时,则
  2. 如果
    是凹函数,
    是随机变量,则
    ,当
    是(严格)凹函数当且仅当
    是(严格)凸函

2e7208754ca77ef406afca28d96332a3.png

EM思想

极大似然函数法估计参数的一般步骤:

  1. 写出似然函数
  2. 取对数
  3. 求导数,并令导数为0
  4. 解似然方程

给定

个训练样本
,假设样本之间相互独立,要拟合模型
。根据分布我们可以得到如下的似然函数:

需要对每个样本实例的每个可能的类别

求联合概率分布之和,即

如果

是已知的,那么使用极大似然估计参数
会很容易。

然而上式存在一个不确定的隐含变量(latent random variable)

,这种情况下EM算法就派上用场了。

由于不能直接最大化

,所以只能不断地建立
的下界(E-step),再优化下界。一直迭代直到算法收敛到局部最优解。

EM算法通过引入隐含变量,使用MLE(极大似然估计)进行迭代求解参数。通常引入隐含变量后会有两个参数,EM算法首先会固定其中一个参数,然后使用MLE计算第二个参数;然后固定第二个参数,再使用MLE估计第一个参数的值。依次迭代,直到收敛到局部最优解。

  • E-Step: 通过观察到的状态和现有模型估计参数估计值 隐含状态
  • M-Step: 假设隐含状态已知的情况下,最大化似然函数。

由于算法保证了每次迭代之后,似然函数都会增加,所以函数最终会收敛

EM推导

对于每个实例

,用
表示样本实例 隐含变量
的某种分布,且
满足条件
,如果
是连续的,则
表示概率密度函数,将求和换成积分。

上式最后的变换用到了Jensen不等式:

由于

函数的二阶导数为
,为凹函数,所以使用

把所以上式写成

,那么我们可以通过不断的优化
的下界,来使得
不断提高,最终达到它的最大值。

在Jensen不等式中,当

,即为常数时,等号成立。在这里即为:

变换并对

求和得到:

因为

,概率之和为1,所以:

因此:

可以看出,固定了参数

之后,使下界拉升的
的计算公式就是后验概率,一并解决了
如何选择的问题。

EM完整的流程如下:

  1. 初始化参数分布
  2. 重复E-Step和M-Step直到收敛
    1. E-Step: 根据参数的初始值或者上一次迭代的模型参数来计算出隐含变量的后验概率,其实就是隐含变量的期望值,作为隐含变量的当前估计值:
    2. M-Step: 最大化似然函数从而获得新的参数值:

多维高斯分布

一元高斯分布的概率密度函数为:

因为

是标量,所以
等价于
,所以上式等价于

推广到多维得到多元高斯分布,得到K维随机变量

的概率密度函数:

都是K维向量
是协方差阵的行列式,协方差阵
的正定矩阵,称
服从K元正态分布,简记为:

多元高斯分布的极大似然估计

对于

个样本
,其似然函数为:

分别对

求偏导,参考多元正态分布的极大似然估计得到:

# use multivariate_normal to generate 2d gaussian distribution
mean = [3, 4]
cov = [[1.5, 0], [0, 3.3]]
x = np.random.multivariate_normal(mean, cov, 500)
plt.scatter(x[:, 0], x[:, 1])mu_hat = np.mean(x, axis=0)
print(mu_hat)
sigma_hat = ((x-mu_hat).T @ (x-mu_hat)) / 500
print(sigma_hat)
#[2.89991371 4.08421214]
#[[ 1.43340175 -0.01134683]#[-0.01134683  3.28850441]]

6ce8576a59becc23b3b61dc81f1af6da.png
二元高斯分布

高斯混合模型

生成一维的高斯分布

:

sigma * np.random.randn(...) + mu

生成二维分布需要乘以协方差矩阵(协方差矩阵是正定的,所以可以分解(Cholesky)成下三角矩阵),

二维高斯分布的参数分析

from scipy.stats import multivariate_normaldef gen_gaussian(conv, mean, num=1000):points = np.random.randn(num, 2)points = points @ conv + meanreturn pointsfig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(111)conv, mean = np.array([[1, 0], [0, 5]]), np.array([2, 4])
points1 = gen_gaussian(conv, mean)
plt.scatter(points1[:, 0], points1[:, 1])conv, mean = np.array([[2, 0], [0, 3]]), np.array([10, 15])
points2 = gen_gaussian(conv, mean)
plt.scatter(points2[:, 0], points2[:, 1])points = np.append(points1, points2, axis=0)K = 2
X = points
mu = np.array([[2, 4], [10, 15]])
cov = np.array([[[1, 0], [0, 5]], [[2, 0], [0, 3]]])x, y = np.meshgrid(np.sort(X[:, 0]), np.sort(X[:, 1]))
XY = np.array([x.flatten(), y.flatten()]).T
reg_cov = 1e-6 * np.identity(2)    
for m, c in zip(mu, cov):c = c + reg_covmng = multivariate_normal(mean=m, cov=c)ax.contour(np.sort(X[:, 0]), np.sort(X[:, 1]), mng.pdf(XY).reshape(len(X), len(X)), colors='black', alpha=0.3)

cd5d6af3084d771f6c475bdbb6da0c97.png
两个二元高斯分布混合的分布
  1. 定义分量数目
    ,对每个分量设置
    ,然后计算下式的对数似然函数

2. E-Step,根据当前的

计算后验概率
是先验概率,
表示点
属于聚类
的后验概率。

3. M-Step,更新

4. 检查是否收敛,否则转#2

# https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.stats.multivariate_normal.htmlimport math# implement my own gaussian pdf, get same result as multivariate_normal.pdf
def gaussian(X, K, mu, cov):p = 1.0 / np.sqrt(np.power(2*math.pi, K) * np.linalg.det(cov))i = (X-mu).T @ np.linalg.inv(cov) @ (X-mu)p *= np.power(np.e, -0.5*i)return pX = np.random.rand(2,)
mu = np.random.rand(2, )
cov = np.array([[1, 0], [0, 3]])mng = multivariate_normal(mean=mu, cov=cov)
gaussian(X, 2, mu, cov), mng.pdf(X)

输出 (0.08438545576275427, 0.08438545576275427)

# This import registers the 3D projection, but is otherwise unused.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 unused import
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
import numpy as npfig = plt.figure()
ax = fig.gca(projection='3d')# Make data.
X = np.arange(-5, 5, 0.25)
Y = np.arange(-5, 5, 0.25)
X, Y = np.meshgrid(X, Y)
Z = np.array([gaussian(x, 2, mu, cov) for x in zip(X.flatten(), Y.flatten())]).reshape(X.shape)print(X.shape, Y.shape, Z.shape)
# Plot the surface.
surf = ax.plot_surface(X, Y, Z, cmap=plt.cm.coolwarm,linewidth=0, antialiased=False)

b7ad88f589b50c3d350e649b4d98ec36.png
二元高斯分布的概率密度函数
from tqdm import tqdmX = points
N, K, D = len(points), 2, len(X[0])
mu = np.random.randint(min(X[:,0]),max(X[:,0]),size=(K, D)) 
d = np.max(X)
rcov = 1e-6*np.identity(D)
cov = np.zeros((K, D, D))
for dim in range(K):np.fill_diagonal(cov[dim], d)pi0 = np.random.rand()
pi = np.array([pi0, 1-pi0])rnk = np.zeros((N, K))
muh, covh, Rh = [], [], []log_likelihoods = []for i in tqdm(range(100)):muh.append(mu)covh.append(cov)# E-Steprnk = np.zeros((N, K))    for m, co, p, k in zip(mu, cov, pi, range(K)):co = co + rcovmng = multivariate_normal(mean=m, cov=co)d = np.sum([pi_k * multivariate_normal(mean=mu_k, cov=cov_k).pdf(X) for pi_k, mu_k, cov_k in zip(pi, mu, cov+rcov)], axis=0)rnk[:, k] = p * mng.pdf(X) / dRh.append(rnk)#     for n in range(N):        
#         d = sum([pi[k]*gaussian(X[n], K, mu[k], cov[k]) for k in range(K)])
#         for k in range(K):
#             rnk[n, k] = pi[k] * gaussian(X[n], K, mu[k], cov[k]) / d# M-Stepmu, cov, pi = np.zeros((K, D)), np.zeros((K, D, D)), np.zeros((K, 1))    for k in range(K):nk = np.sum(rnk[:, k], axis=0)# new meanmuk = (1/nk) * np.sum(X*rnk[:, k].reshape(N, 1), axis=0)mu[k] = muk# new conv matrixcovk = (rnk[:, k].reshape(N, 1) * (X-muk)).T @ (X-muk) + rcovcov[k] = covk / nk# new pipi[k] = nk / np.sum(rnk)log_likelihoods.append(np.log(np.sum([p*multivariate_normal(mu[i], cov[k]).pdf(X) for p, i, k in zip(pi, range(len(X[0])), range(K))])))plt.plot(log_likelihoods, label='log_likelihoods')
plt.legend()fig = plt.figure()
ax = fig.add_subplot(111)
plt.scatter(points1[:, 0], points1[:, 1])
plt.scatter(points2[:, 0], points2[:, 1])for m, c in zip(mu, cov):mng = multivariate_normal(mean=m,cov=c)ax.contour(np.sort(X[:, 0]), np.sort(X[:, 1]), mng.pdf(XY).reshape(len(X), len(X)),colors='black',alpha=0.3)

e889006b2812cac4eec8b8034a4c5ddc.png
对数似然函数的收敛过程

b838c48a60e039e4c62b586553908976.gif

参考

WIKI多元正态分布

期望最大算法

二维高斯分布的参数分析

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/433102.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

wifi 小米pro 驱动 黑苹果_搞定小米黑苹果自带WIF,又可省一个USB接口了

首先声明我的是小米笔记本PRO版本的,其他版本的没有经过测试,但理论都是没有问题的,其他版本的朋友,喜欢折腾的话,可以试试!自用版本关于小米笔记本安装黑苹果,网上一直都有很多链接&#xff0c…

代理模式 委派模式 策略模式_策略模式

在策略模式(Strategy Pattern)中,一个类的行为或其算法可以在运行时更改。这种类型的设计模式属于行为型模式。在策略模式中,我们创建表示各种策略的对象和一个行为随着策略对象改变而改变的 context 对象。策略对象改变 context 对象的执行算法。介绍意…

例2-1

#include<stdio.h> int main(void) {printf("Hello World!\n");return 0; } 转载于:https://www.cnblogs.com/520zy/p/3348951.html

java第七章jdbc课后简答题_Java周测题08.13

1.关于Mybatis的描述正确的是&#xff1a;Mybatis是持久层框架&#xff0c;Mybatis封装了JDBC&#xff0c;Mybatis简化了代码的编辑和使用&#xff0c;Mybatis是一个半ORM&#xff08;对象关系映射&#xff09;框架&#xff0c;Mybatis采用了OCP(对象关系映射)的方式封装了数据…

linux中probe函数中传递的参数来源(上)

linux中probe函数传递参数的寻找&#xff08;上&#xff09; 上一篇中&#xff0c;我们追踪了probe函数在何时调用&#xff0c;知道了满足什么条件会调用probe函数&#xff0c;但probe函数中传递的参数我们并不知道在何时定义&#xff0c;到底是谁定义的&#xff0c;反正不是我…

linux中probe函数传递参数的寻找(下)

linux中probe函数传递参数的寻找&#xff08;下&#xff09; 通过追寻driver的脚步&#xff0c;我们有了努力的方向&#xff1a;只有找到spi_bus_type的填充device即可&#xff0c;下面该从device去打通&#xff0c;当两个连通之日&#xff0c;也是任督二脉打通之时。先从设备定…

服务器部署 配置jetty运行参数_Zookeeper+websocket实现对分布式服务器的实时监控...

Zookeeper简介Zookeeper是Hadoop的一个子项目&#xff0c;它是分布式系统中的协调系统。简单来说就是一个Zookeeper注册同步中心&#xff0c;内部结构为一个树形目录&#xff0c;每个节点上可以存放一定量(默认的数据量上限是1M&#xff0c;但是可以通过调整参数修改)的数据&am…

Python Interview Question and Answers

引文&#xff1a;http://ilian.i-n-i.org/python-interview-question-and-answers/ For the last few weeks I have been interviewing several people for Python/Django developers so I thought that it might be helpful to show the questions I am asking together with …

软件工程项目总结_复旦大学软件工程实验室来ASE实验室交流

2020年12月11日下午&#xff0c;复旦大学彭鑫教授一行与我院多智能体软件工程实验室开展科研工作交流。本次交流会议旨在为双方建立沟通桥梁&#xff0c;探讨研究问题&#xff0c;谋划后续合作&#xff0c;促使双方增进了解、加强互动、互相学习、共同进步。学院党委书记、多智…

windows无法发现任何计算机或设备,Win10系统提示windows无法与设备或资源通信如何解决...

最近有win10系统用户发现电脑无法打开网页&#xff0c;然后进行网络诊断的时候&#xff0c;提示“Windows无法与设备或资源(主DNS) 通信”&#xff0c;该怎么解决这样的问题呢&#xff1f;接下来给大家带来Win10系统提示windows无法与设备或资源通信的具体解决步骤。一、更改DN…

scrapy 中不同页面的拼接_scrapy使用技巧总结

1. scrapy运行过程概述scrapy是一个基于python的网络爬虫框架&#xff0c;它读取对指定域名的网页request请求&#xff0c;截取对应域名的返回体&#xff0c;开发者可以编写解析函数&#xff0c;从返回体中抓取自己需要的数据&#xff0c;并对数据进行清洗处理或存入数据库。sc…

Buffers, windows, and tabs

If you’ve moved to Vim from an editor like Notepad or TextMate, you’ll be used to working with the idea of tabs in a text editor in a certain way. Specifically, a tab represents an open file; while the tab’s there, you’ve got an open file, as soon as y…

docker访问宿主机mysql_docker容器内访问宿主机127.0.0.1服务

点击上方”技术生活“&#xff0c;选择“设为星标”做积极的人&#xff0c;而不是积极废人背景原因分析解决方案背景已经通过docker启动的elasticsearch 服务&#xff0c;监听端口9200。在宿主机中直接通过http://127.0.0.1:9200 可以直接访问&#xff0c;但是通过docker访问缺…

ADO.NET+Access: 3,参数 @departmentName 没有默认值

ylbtech-Error-ADO.NETAccess: 3,参数 departmentName 没有默认值。1.A,错误代码返回顶部 3,参数 departmentName 没有默认值。1.B,出错原因分析返回顶部未解决1.C,相关解决方法返回顶部作者&#xff1a;ylbtech出处&#xff1a;http://ylbtech.cnblogs.com/本文版权归作者和博…

lombok有参构造注解_Java高效开发工具: Lombok

Lombok, 一个Java开发必备效率工具&#xff0c;可以大大避免编写一些常用方法(get/set, hashcode等)&#xff0c;简化开发。虽然现在IDE很多都可以通过快捷键生成POJO的一些方法了&#xff0c;但是如果该POJO字段发生变动后&#xff0c;还是需要程序员再次手动重新生成相关方法…

JavaScript操作大全整理(思维导图三--函数基础)

3.JavaScript函数基础 转载于:https://www.cnblogs.com/yuxia/p/3360806.html

nginx指定配置文件启动_NGINX安全加固手册

NIGNX系统安全基线规范1.概述1.1 适用范围本配置标准的使用者包括&#xff1a;各事业部服务器负责人。 各事业部服务器负责人按规范要求进行认证、日志、协议、补丁升级、文件系统管理等方面的安全配置要求。对系统的安全配置审计、加固操作起到指导性作用。1.2 文档内容本文档…

口袋网咖已有服务器在使用怎么注销,口袋网咖_口袋网咖常见问题_口袋网咖专区...

口袋网咖是专门为游戏高玩打造的手机变电脑软件&#xff0c;虚拟电脑神器&#xff0c;体验各种电脑游戏&#xff0c;非常的方便&#xff0c;能让小伙伴尽情的体验手机电脑的感觉&#xff0c;很多小伙伴在使用过程中遇到了一些问题&#xff0c;快啦网为大家分享口袋网咖常见问题…

统计个人已完成的工作量_团队工作量及团队价值贡献统计、核算、评审及提升的重要性...

在推行阿米巴经营模式时&#xff0c;需要进行企业内部产品及服务全价值分析&#xff0c;也就是企业内部团队产品及服务价值增值的全过程分析&#xff0c;团队价值增值是团队存在的目的和意义&#xff0c;对于团队经营来讲&#xff0c;团队工作量就团队的收入&#xff0c;团队价…

hyper服务器虚拟网卡和实际网卡,Hyper-V 3 虚拟网卡带宽应用限制

Windows Server 2012的Hyper-V 3中&#xff0c;打来了系列新功能&#xff0c;例如网卡流量限制功能。 基础架构注意的问题宿主服务器规划过程中&#xff0c;管理员主要考虑服务器基础架构中的CPU、内存、磁盘空间等必要因素&#xff0c;但是网络适配器(简称网卡)通常属于被忽略…