em算法 实例 正态分布_EM算法解GMM

看了很多介绍EM算法的文章,但是他们都没有代码,所以在这里写出来。

Jensen 不等式

参考期望最大算法

Jensen不等式在优化理论中大量用到,首先来回顾下凸函数和凹函数的定义。假设

是定义域为实数的函数,如果对于所有的
的二阶导数大于等于0,那么
是凸函数。当
是向量时,如果hessian矩阵
是半正定(即
是凸函数)。如果,
的二阶导数小于0或者
就是凹函数。

Jensen不等式描述如下:

  1. 如果
    是凸函数,
    是随机变量,则
    是严格凸函数时,则
  2. 如果
    是凹函数,
    是随机变量,则
    ,当
    是(严格)凹函数当且仅当
    是(严格)凸函

2e7208754ca77ef406afca28d96332a3.png

EM思想

极大似然函数法估计参数的一般步骤:

  1. 写出似然函数
  2. 取对数
  3. 求导数,并令导数为0
  4. 解似然方程

给定

个训练样本
,假设样本之间相互独立,要拟合模型
。根据分布我们可以得到如下的似然函数:

需要对每个样本实例的每个可能的类别

求联合概率分布之和,即

如果

是已知的,那么使用极大似然估计参数
会很容易。

然而上式存在一个不确定的隐含变量(latent random variable)

,这种情况下EM算法就派上用场了。

由于不能直接最大化

,所以只能不断地建立
的下界(E-step),再优化下界。一直迭代直到算法收敛到局部最优解。

EM算法通过引入隐含变量,使用MLE(极大似然估计)进行迭代求解参数。通常引入隐含变量后会有两个参数,EM算法首先会固定其中一个参数,然后使用MLE计算第二个参数;然后固定第二个参数,再使用MLE估计第一个参数的值。依次迭代,直到收敛到局部最优解。

  • E-Step: 通过观察到的状态和现有模型估计参数估计值 隐含状态
  • M-Step: 假设隐含状态已知的情况下,最大化似然函数。

由于算法保证了每次迭代之后,似然函数都会增加,所以函数最终会收敛

EM推导

对于每个实例

,用
表示样本实例 隐含变量
的某种分布,且
满足条件
,如果
是连续的,则
表示概率密度函数,将求和换成积分。

上式最后的变换用到了Jensen不等式:

由于

函数的二阶导数为
,为凹函数,所以使用

把所以上式写成

,那么我们可以通过不断的优化
的下界,来使得
不断提高,最终达到它的最大值。

在Jensen不等式中,当

,即为常数时,等号成立。在这里即为:

变换并对

求和得到:

因为

,概率之和为1,所以:

因此:

可以看出,固定了参数

之后,使下界拉升的
的计算公式就是后验概率,一并解决了
如何选择的问题。

EM完整的流程如下:

  1. 初始化参数分布
  2. 重复E-Step和M-Step直到收敛
    1. E-Step: 根据参数的初始值或者上一次迭代的模型参数来计算出隐含变量的后验概率,其实就是隐含变量的期望值,作为隐含变量的当前估计值:
    2. M-Step: 最大化似然函数从而获得新的参数值:

多维高斯分布

一元高斯分布的概率密度函数为:

因为

是标量,所以
等价于
,所以上式等价于

推广到多维得到多元高斯分布,得到K维随机变量

的概率密度函数:

都是K维向量
是协方差阵的行列式,协方差阵
的正定矩阵,称
服从K元正态分布,简记为:

多元高斯分布的极大似然估计

对于

个样本
,其似然函数为:

分别对

求偏导,参考多元正态分布的极大似然估计得到:

# use multivariate_normal to generate 2d gaussian distribution
mean = [3, 4]
cov = [[1.5, 0], [0, 3.3]]
x = np.random.multivariate_normal(mean, cov, 500)
plt.scatter(x[:, 0], x[:, 1])mu_hat = np.mean(x, axis=0)
print(mu_hat)
sigma_hat = ((x-mu_hat).T @ (x-mu_hat)) / 500
print(sigma_hat)
#[2.89991371 4.08421214]
#[[ 1.43340175 -0.01134683]#[-0.01134683  3.28850441]]

6ce8576a59becc23b3b61dc81f1af6da.png
二元高斯分布

高斯混合模型

生成一维的高斯分布

:

sigma * np.random.randn(...) + mu

生成二维分布需要乘以协方差矩阵(协方差矩阵是正定的,所以可以分解(Cholesky)成下三角矩阵),

二维高斯分布的参数分析

from scipy.stats import multivariate_normaldef gen_gaussian(conv, mean, num=1000):points = np.random.randn(num, 2)points = points @ conv + meanreturn pointsfig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(111)conv, mean = np.array([[1, 0], [0, 5]]), np.array([2, 4])
points1 = gen_gaussian(conv, mean)
plt.scatter(points1[:, 0], points1[:, 1])conv, mean = np.array([[2, 0], [0, 3]]), np.array([10, 15])
points2 = gen_gaussian(conv, mean)
plt.scatter(points2[:, 0], points2[:, 1])points = np.append(points1, points2, axis=0)K = 2
X = points
mu = np.array([[2, 4], [10, 15]])
cov = np.array([[[1, 0], [0, 5]], [[2, 0], [0, 3]]])x, y = np.meshgrid(np.sort(X[:, 0]), np.sort(X[:, 1]))
XY = np.array([x.flatten(), y.flatten()]).T
reg_cov = 1e-6 * np.identity(2)    
for m, c in zip(mu, cov):c = c + reg_covmng = multivariate_normal(mean=m, cov=c)ax.contour(np.sort(X[:, 0]), np.sort(X[:, 1]), mng.pdf(XY).reshape(len(X), len(X)), colors='black', alpha=0.3)

cd5d6af3084d771f6c475bdbb6da0c97.png
两个二元高斯分布混合的分布
  1. 定义分量数目
    ,对每个分量设置
    ,然后计算下式的对数似然函数

2. E-Step,根据当前的

计算后验概率
是先验概率,
表示点
属于聚类
的后验概率。

3. M-Step,更新

4. 检查是否收敛,否则转#2

# https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.stats.multivariate_normal.htmlimport math# implement my own gaussian pdf, get same result as multivariate_normal.pdf
def gaussian(X, K, mu, cov):p = 1.0 / np.sqrt(np.power(2*math.pi, K) * np.linalg.det(cov))i = (X-mu).T @ np.linalg.inv(cov) @ (X-mu)p *= np.power(np.e, -0.5*i)return pX = np.random.rand(2,)
mu = np.random.rand(2, )
cov = np.array([[1, 0], [0, 3]])mng = multivariate_normal(mean=mu, cov=cov)
gaussian(X, 2, mu, cov), mng.pdf(X)

输出 (0.08438545576275427, 0.08438545576275427)

# This import registers the 3D projection, but is otherwise unused.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 unused import
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
import numpy as npfig = plt.figure()
ax = fig.gca(projection='3d')# Make data.
X = np.arange(-5, 5, 0.25)
Y = np.arange(-5, 5, 0.25)
X, Y = np.meshgrid(X, Y)
Z = np.array([gaussian(x, 2, mu, cov) for x in zip(X.flatten(), Y.flatten())]).reshape(X.shape)print(X.shape, Y.shape, Z.shape)
# Plot the surface.
surf = ax.plot_surface(X, Y, Z, cmap=plt.cm.coolwarm,linewidth=0, antialiased=False)

b7ad88f589b50c3d350e649b4d98ec36.png
二元高斯分布的概率密度函数
from tqdm import tqdmX = points
N, K, D = len(points), 2, len(X[0])
mu = np.random.randint(min(X[:,0]),max(X[:,0]),size=(K, D)) 
d = np.max(X)
rcov = 1e-6*np.identity(D)
cov = np.zeros((K, D, D))
for dim in range(K):np.fill_diagonal(cov[dim], d)pi0 = np.random.rand()
pi = np.array([pi0, 1-pi0])rnk = np.zeros((N, K))
muh, covh, Rh = [], [], []log_likelihoods = []for i in tqdm(range(100)):muh.append(mu)covh.append(cov)# E-Steprnk = np.zeros((N, K))    for m, co, p, k in zip(mu, cov, pi, range(K)):co = co + rcovmng = multivariate_normal(mean=m, cov=co)d = np.sum([pi_k * multivariate_normal(mean=mu_k, cov=cov_k).pdf(X) for pi_k, mu_k, cov_k in zip(pi, mu, cov+rcov)], axis=0)rnk[:, k] = p * mng.pdf(X) / dRh.append(rnk)#     for n in range(N):        
#         d = sum([pi[k]*gaussian(X[n], K, mu[k], cov[k]) for k in range(K)])
#         for k in range(K):
#             rnk[n, k] = pi[k] * gaussian(X[n], K, mu[k], cov[k]) / d# M-Stepmu, cov, pi = np.zeros((K, D)), np.zeros((K, D, D)), np.zeros((K, 1))    for k in range(K):nk = np.sum(rnk[:, k], axis=0)# new meanmuk = (1/nk) * np.sum(X*rnk[:, k].reshape(N, 1), axis=0)mu[k] = muk# new conv matrixcovk = (rnk[:, k].reshape(N, 1) * (X-muk)).T @ (X-muk) + rcovcov[k] = covk / nk# new pipi[k] = nk / np.sum(rnk)log_likelihoods.append(np.log(np.sum([p*multivariate_normal(mu[i], cov[k]).pdf(X) for p, i, k in zip(pi, range(len(X[0])), range(K))])))plt.plot(log_likelihoods, label='log_likelihoods')
plt.legend()fig = plt.figure()
ax = fig.add_subplot(111)
plt.scatter(points1[:, 0], points1[:, 1])
plt.scatter(points2[:, 0], points2[:, 1])for m, c in zip(mu, cov):mng = multivariate_normal(mean=m,cov=c)ax.contour(np.sort(X[:, 0]), np.sort(X[:, 1]), mng.pdf(XY).reshape(len(X), len(X)),colors='black',alpha=0.3)

e889006b2812cac4eec8b8034a4c5ddc.png
对数似然函数的收敛过程

b838c48a60e039e4c62b586553908976.gif

参考

WIKI多元正态分布

期望最大算法

二维高斯分布的参数分析

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/433102.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

html5 webview,HTML5+学习历程之webview经典案例

看了这么多app,其实基本布局使用的最多的无非两种,如下图:类似微信类似QQ在这里小编简单说下这两种布局简单的实现思路,当然如果你还有更好的方法,请在下面留言,让更多人知道你更好的方法!第一种…

C函数的实现(strcpy,atoi,atof,itoa,reverse)

在笔试面试中经常会遇到让你实现C语言中的一些函数比如strcpy,atoi等 1. atoi 把字符串s转换成数字 int Atoi( char *s ) {int num 0, i 0;int sign 1;for( i0; isspace(s[i]); i );sign (s[i] -)? -1:1;if( s[i] || s[i] - )i;for( ;isdigit(s[i]); i ){n…

bbb sdk6 ll_rw_block分析

ll_rw_block是文件系统对下访问实际的块设备驱动的接口,应用程序对实际文件(非设备文件)的操作,最终都是 通过文件系统来调用ll_rw_block来操作实际的存储设备的。 当然ll_rw_block的实际作用远非一个接口那么简单,他…

wifi 小米pro 驱动 黑苹果_搞定小米黑苹果自带WIF,又可省一个USB接口了

首先声明我的是小米笔记本PRO版本的,其他版本的没有经过测试,但理论都是没有问题的,其他版本的朋友,喜欢折腾的话,可以试试!自用版本关于小米笔记本安装黑苹果,网上一直都有很多链接&#xff0c…

教师资格证计算机考察知识点,教师资格证考试信息技术常考知识点同步练习题.docx...

教师资格证考试信息技术常考知识点同步练习题一、信息的定义及特征( 一) 信息定义信息是通过文字、数字、图像、图形、声音、视频等方式进行传播的内容。说明:信息定义考查的方式有两类:一类是选出四个选项中是信息的 ; 另一类是判断选择题,选…

machine learning for hacker记录(4) 智能邮箱(排序学习推荐系统)

本章是上一章邮件过滤技术的延伸,上一章的内容主要是过滤掉垃圾邮件,而这里要讲的是对那些正常的邮件是否可以加入个性化元素,由于每个用户关心的主题并非一样(有人喜欢技术类型的邮件或者购物促销方便的内容邮件等)。…

代理模式 委派模式 策略模式_策略模式

在策略模式(Strategy Pattern)中,一个类的行为或其算法可以在运行时更改。这种类型的设计模式属于行为型模式。在策略模式中,我们创建表示各种策略的对象和一个行为随着策略对象改变而改变的 context 对象。策略对象改变 context 对象的执行算法。介绍意…

云南计算机专业知识真题,2014年云南省事业单位考试专计算机专业知识模拟真题.doc...

2014年云南省事业单位考试专计算机专业知识模拟真题1 在Word中替换的快捷键是____。A、CTRLFB、CTRLHC、CTRLSD、CTRLP2 在Word中打印的快捷键是____。A、CTRLFB、CTRLHC、CTRLOD、CTRLP3 在Word中打开新文档的快捷键是____。A、CTRLFB、CTRLHC、CTRLOD、CTRLP4 在Word中&#…

bbb mmc_blk_probe 分析

bbb 的 emmc驱动在drivers\mmc\card\block.c,其mmc_dirver结构体如下, 根据以往平台总线驱动模型的经验来看的话,内核里应该有mmc_devices结构体,并且 其name也为"mmcblk",这样其probe函数将被调用&#x…

培智学校计算机课教案,培智数学教案

教学内容:11—20以内数的认识 写数 教学目的:1、使学生能初步地数、读、写(本节课重点看图写20以内的数。) 2、初步会写小棒图、数位表上的数,掌握20以内数的顺序。3、初步简单掌握20以内数的组成。 教学重点:学生看图会数数量并会…

例2-1

#include<stdio.h> int main(void) {printf("Hello World!\n");return 0; } 转载于:https://www.cnblogs.com/520zy/p/3348951.html

java第七章jdbc课后简答题_Java周测题08.13

1.关于Mybatis的描述正确的是&#xff1a;Mybatis是持久层框架&#xff0c;Mybatis封装了JDBC&#xff0c;Mybatis简化了代码的编辑和使用&#xff0c;Mybatis是一个半ORM&#xff08;对象关系映射&#xff09;框架&#xff0c;Mybatis采用了OCP(对象关系映射)的方式封装了数据…

linux中probe函数中传递的参数来源(上)

linux中probe函数传递参数的寻找&#xff08;上&#xff09; 上一篇中&#xff0c;我们追踪了probe函数在何时调用&#xff0c;知道了满足什么条件会调用probe函数&#xff0c;但probe函数中传递的参数我们并不知道在何时定义&#xff0c;到底是谁定义的&#xff0c;反正不是我…

2018高职计算机474分排名,2018年高职分类考试招生录取分数线出炉

原标题&#xff1a;2018年高职分类考试招生录取分数线出炉记者 李洁昨天&#xff0c;实况新闻—重庆时报记者从市教育考试院获悉&#xff0c;2018年我市高等职业教育分类考试招生录取最低控制分数线已划定。一、普高类(一)普通文理类1.专本贯通分段培养项目批文史类&#xff1a…

linux中probe函数传递参数的寻找(下)

linux中probe函数传递参数的寻找&#xff08;下&#xff09; 通过追寻driver的脚步&#xff0c;我们有了努力的方向&#xff1a;只有找到spi_bus_type的填充device即可&#xff0c;下面该从device去打通&#xff0c;当两个连通之日&#xff0c;也是任督二脉打通之时。先从设备定…

服务器部署 配置jetty运行参数_Zookeeper+websocket实现对分布式服务器的实时监控...

Zookeeper简介Zookeeper是Hadoop的一个子项目&#xff0c;它是分布式系统中的协调系统。简单来说就是一个Zookeeper注册同步中心&#xff0c;内部结构为一个树形目录&#xff0c;每个节点上可以存放一定量(默认的数据量上限是1M&#xff0c;但是可以通过调整参数修改)的数据&am…

Python Interview Question and Answers

引文&#xff1a;http://ilian.i-n-i.org/python-interview-question-and-answers/ For the last few weeks I have been interviewing several people for Python/Django developers so I thought that it might be helpful to show the questions I am asking together with …

2018年海南计算机职称考试,海南省2018年全国计算机等级考试报名时间

关于延长2018年3月全国计算机等级考试报名时间的公告2018年3月全国计算机等级考试报名时间原定为2017年12月11日-26日&#xff0c;为了满足广大考生报考的需要&#xff0c;现决定将报名时间延长至2017年12月29日17&#xff1a;00。请符合报考条件的考生及时上网填报报名信息和缴…

linux中 probe函数的何时调用的?

linux中 probe函数何时调用的 所以的驱动教程上都说&#xff1a;只有设备和驱动的名字匹配&#xff0c;BUS就会调用驱动的probe函数&#xff0c;但是有时我们要看看probe函数里面到底做了什么&#xff0c;还有传递给probe函数的参数我们就不知道在哪定义&#xff08;反正不是我…

软件工程项目总结_复旦大学软件工程实验室来ASE实验室交流

2020年12月11日下午&#xff0c;复旦大学彭鑫教授一行与我院多智能体软件工程实验室开展科研工作交流。本次交流会议旨在为双方建立沟通桥梁&#xff0c;探讨研究问题&#xff0c;谋划后续合作&#xff0c;促使双方增进了解、加强互动、互相学习、共同进步。学院党委书记、多智…