机器学习算法总结--K均值算法

参考自:

  • 《机器学习》
  • 机器学习&数据挖掘笔记_16(常见面试之机器学习算法思想简单梳理)
  • K-Means Clustering
  • 斯坦福大学公开课 :机器学习课程

简介

K-均值是最普及的聚类算法,算法接受一个未标记的数据集,然后将数据集聚类成不同的组。

K-均值是一个迭代算法,假设我们想要将数据聚类成n个组,其方法为:

  1. 首先选择K个随机的点,称其为聚类中心
  2. 对于数据集中的每一个数据,按照距离K个中心点的距离,将其与距离最近的中心点关联起来,与同一个中心点关联的所有点聚成一个类
  3. 计算每一个组的平均值,将该组所关联的中心点移动到平均值的位置
  4. 重复步骤2-3,直到中心点不再变化

这个过程中分两个主要步骤,第一个就是第二步,将训练集中的样本点根据其与聚类中心的距离,分配到距离最近的聚类中心处,接着第二个就是第三步,更新类中心,做法是计算每个类的所有样本的平均值,然后将这个平均值作为新的类中心值,接着继续这两个步骤,直到达到终止条件,一般是指达到设定好的迭代次数。

当然在这个过程中可能遇到有聚类中心是没有分配数据点给它的,通常的一个做法是删除这种聚类中心,或者是重新选择聚类中心,保证聚类中心数还是初始设定的K个

优化目标

K-均值最小化问题,就是最小化所有的数据点与其所关联的聚类中心之间的距离之和,因此K-均值的代价函数(又称为畸变函数)为: 

J(c(1),c(2),,c(m),μ1,μ2,,μm)=1mmi=1||x(i)μc(i)||2

其中 μc(i)代表与 x(i)最近的聚类中心点。

所以我们的优化目标是找出是的代价函数最小的c(1),c(2),,c(m)μ1,μ2,,μm

minc(1),c(2),,c(m),μ1,μ2,,μmJ(c(1),c(2),,c(m),μ1,μ2,,μm)

回顾K-均值迭代算法的过程可知,第一个循环就是用于减小 c(i)引起的代价,而第二个循环则是用于减小 μi引起的代价,因此, 迭代的过程一定会是每一次迭代都在减小代价函数,不然便是出现了错误。

随机初始化

在运行K-均值算法之前,首先需要随机初始化所有的聚类中心点,做法如下:

  1. 首先应该选择K<m,即聚类中心点的个数要小于所有训练集实例的数量
  2. 随机选择K个训练实例,然后令K个聚类中心分别于这K个训练实例相等

K-均值的一个问题在于,它有可能会停留在一个局部最小值处,而这取决于初始化的情况。

为了解决这个问题,通常需要多次运行K-均值算法,每一次都重新进行随机初始化,最后再比较多次运行K-均值的结果,选择代价函数最小的结果。这种方法在K较小(2-10)的时候还是可行的,但是如果K较大,这种做法可能不会有明显地改善。

优缺点

优点

  1. k-means算法是解决聚类问题的一种经典算法,算法简单、快速
  2. 对处理大数据集,该算法是相对可伸缩的和高效率的,因为它的复杂度大约是O(nkt),其中n是所有对象的数目,k是簇的数目,t是迭代的次数。通常k<<n。这个算法通常局部收敛
  3. 算法尝试找出使平方误差函数值最小的k个划分。当簇是密集的、球状或团状的,且簇与簇之间区别明显时,聚类效果较好。

缺点

  1. k-平均方法只有在簇的平均值被定义的情况下才能使用,且对有些分类属性的数据不适合。
  2. 要求用户必须事先给出要生成的簇的数目k<script type="math/tex" id="MathJax-Element-16">k</script>。
  3. 对初值敏感,对于不同的初始值,可能会导致不同的聚类结果。
  4. 不适合于发现非凸面形状的簇,或者大小差别很大的簇
  5. 对于“噪声”和孤立点数据敏感,少量的该类数据能够对平均值产生极大影响。

代码实现

代码参考自K-Means Clustering。

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time    : 2016/10/21 16:35
@Author  : cai实现 K-Means 聚类算法
"""import numpy as np
import matplotlib.pyplot as plt
from scipy.io import loadmat
import os# 寻址最近的中心点
def find_closest_centroids(X, centroids):m = X.shape[0]k = centroids.shape[0]idx = np.zeros(m)for i in range(m):min_dist = 1000000for j in range(k):# 计算每个训练样本和中心点的距离dist = np.sum((X[i, :] - centroids[j, :]) ** 2)if dist < min_dist:# 记录当前最短距离和其中心的索引值min_dist = distidx[i] = jreturn idx# 计算聚类中心
def compute_centroids(X, idx, k):m, n = X.shapecentroids = np.zeros((k, n))for i in range(k):indices = np.where(idx == i)# 计算下一个聚类中心,这里简单的将该类中心的所有数值求平均值作为新的类中心centroids[i, :] = (np.sum(X[indices, :], axis=1) / len(indices[0])).ravel()return centroids# 初始化聚类中心
def init_centroids(X, k):m, n = X.shapecentroids = np.zeros((k, n))# 随机初始化 k 个 [0,m]的整数idx = np.random.randint(0, m, k)for i in range(k):centroids[i, :] = X[idx[i], :]return centroids# 实现 kmeans 算法
def run_k_means(X, initial_centroids, max_iters):m, n = X.shape# 聚类中心的数目k = initial_centroids.shape[0]idx = np.zeros(m)centroids = initial_centroidsfor i in range(max_iters):idx = find_closest_centroids(X, centroids)centroids = compute_centroids(X, idx, k)return idx, centroidsdataPath = os.path.join('data', 'ex7data2.mat')
data = loadmat(dataPath)
X = data['X']initial_centroids = init_centroids(X, 3)
# print(initial_centroids)
# idx = find_closest_centroids(X, initial_centroids)
# print(idx)# print(compute_centroids(X, idx, 3))idx, centroids = run_k_means(X, initial_centroids, 10)
# 可视化聚类结果
cluster1 = X[np.where(idx == 0)[0], :]
cluster2 = X[np.where(idx == 1)[0], :]
cluster3 = X[np.where(idx == 2)[0], :]fig, ax = plt.subplots(figsize=(12, 8))
ax.scatter(cluster1[:, 0], cluster1[:, 1], s=30, color='r', label='Cluster 1')
ax.scatter(cluster2[:, 0], cluster2[:, 1], s=30, color='g', label='Cluster 2')
ax.scatter(cluster3[:, 0], cluster3[:, 1], s=30, color='b', label='Cluster 3')
ax.legend()
plt.show()# 载入一张测试图片,进行测试
imageDataPath = os.path.join('data', 'bird_small.mat')
image = loadmat(imageDataPath)
# print(image)A = image['A']
print(A.shape)# 对图片进行归一化
A = A / 255.# 重新调整数组的尺寸
X = np.reshape(A, (A.shape[0] * A.shape[1], A.shape[2]))
# 随机初始化聚类中心
initial_centroids = init_centroids(X, 16)
# 运行聚类算法
idx, centroids = run_k_means(X, initial_centroids, 10)# 得到最后一次的最近中心点
idx = find_closest_centroids(X, centroids)
# map each pixel to the centroid value
X_recovered = centroids[idx.astype(int), :]
# reshape to the original dimensions
X_recovered = np.reshape(X_recovered, (A.shape[0], A.shape[1], A.shape[2]))# plt.imshow(X_recovered)
# plt.show()

完整代码例子和数据可以查看Kmeans练习代码。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/408926.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

通过日期生成星期几

原理是通过基姆拉尔森计算公式,来根据日期得出星期几 以下是C语言的实现 #include "stdafx.h" #include<iostream> using namespace std; int main() {int year,month,day1,m;;char *cWeekName[] {"星期日","星期一","星期二",…

php超市结算,超市物品结算简易程序代码

System.out.println("购买物品\t" "单价\t" "个数\t" "金额");Scanner in new Scanner(System.in);String String1 ;int a 0;int b 0;int c 0;double sum0;do {System.out.println("请选择你购买的物品");String aSt…

依然老问题:装系统

装windows系统&#xff1a; http://tieba.baidu.com/p/2282428641 装ubuntu: 1.使用 universal-usb-installer制作安装U盘 2.修改BIOS启动顺序为U盘启动优先 3.分区 转载于:https://www.cnblogs.com/owenbeta/archive/2013/04/25/3042528.html

机器学习算法总结--提升方法

参考自&#xff1a; 《统计学习方法》浅谈机器学习基础&#xff08;上&#xff09;Ensemble learning:Bagging,Random Forest,Boosting 简介 提升方法(boosting)是一种常用的统计学习方法&#xff0c;在分类问题中&#xff0c;它通过改变训练样本的权重&#xff0c;学习多个分…

matlab画x的1 3次方,如何用Matlab画出f(x)=f(x-1)+2的x次方*3的图像

如何用Matlab画出f(x)f(x-1)2的x次方*3的图像以下文字资料是由(历史新知网www.lishixinzhi.com)小编为大家搜集整理后发布的内容&#xff0c;让我们赶快一起来看一下吧&#xff01;如何用Matlab画出f(x)f(x-1)2的x次方*3的图像你要画的范围假设是0到10Fzeros(1,10);F(1)1;要有初…

wu** C语言注意点

1 函数的注释2.1 文档的结构2.2 头文件的结构 static, extern2.3 目录结构3 函数设计&#xff1a; 函数输出参数&#xff1a;返回正常值&#xff1b;return返回错误值。 函数中定义局部数组&#xff0c;char str[] "……"影响堆栈 内存泄露检测工具。4 …

机器学习算法总结--GBDT

参考如下 机器学习&#xff08;四&#xff09;— 从gbdt到xgboost机器学习常见算法个人总结&#xff08;面试用&#xff09;xgboost入门与实战&#xff08;原理篇&#xff09; 简介 GBDT是一个基于迭代累加的决策树算法&#xff0c;它通过构造一组弱的学习器&#xff08;树&a…

matlab画半球面,Matlab 绘制3D半球

R10;zreal(zeros(201,201));m0;n0;step 0.1;for x-R:step:Rm m 1;%xfor y-sqrt(R*R - x*x):step:sqrt(R*R - x*x)%yn int32(y / step) R / step 1;%nz(n, m) real(sqrt(R*R - x*x - y*y));endforendfor%zmesh(z);另一种方法(from octave)&#xff1a;function [xx, yy, …

机器学习算法总结--EM算法

参考自 《统计学习方法》机器学习常见算法个人总结&#xff08;面试用&#xff09;从最大似然到EM算法浅解&#xff08;EM算法&#xff09;The EM Algorithm 简介 EM算法&#xff0c;即期望极大算法&#xff0c;用于含有隐变量的概率模型的极大似然估计或极大后验概率估计&am…

流程平台:子表控件(二) - 属性、事件、方法

子表控件的元数据如下&#xff1a;属性、事件、方法&#xff1a; public class SheetSubTableSZ : WebControl, ISheetControl{// 分隔符public const char Separator ;;public const string SeqNoColumnName "序号";// 添加按钮public Button Add;// …

nginx php7 win,Win7配置Nginx+PHP7

NginxNginx有官方native build的32bit版本, 也有cygwin build的64bit版本, 出于稳定性的考虑, 还是选了官方的32bit.解压, 本例中使用的路径是 C:\Servers\nginx-1.9.12 , 创建两个bat, 用于启动和关闭nginx:start_nginx.bat1234echooffsetNGINX_HOMEC:\Servers\nginx-1.9.12st…

(转)Thrift在Windows及Linux平台下的安装和使用示例

转载自Thrift在Windows及Linux平台下的安装和使用示例 thrift介绍 Apache Thrift 是 Facebook 实现的一种高效的、支持多种编程语言的RPC(远程服务调用)框架。 本文主要目的是分别介绍在Windows及Linux平台下的Thrift安装步骤&#xff0c;以及实现一个简单的demo演示Thrif…

CPP第四版第四章:创建动态数组

数组类型的变量有三个重要限制&#xff1a; 数组长度固定不变 在编译时必须知道其长度 数组只在定义它的块语句内存在 每一个程序在执行时都占用一块可用的内存空间&#xff0c;用于存放动态分配的对象&#xff0c;此内存空间称为程序的自由存储区或堆…

matlab中数据变为nan,字符转化为数值型中出现NAN

我将字符型转化为数值型&#xff0c;然后画图&#xff0c;结果图形没有曲线&#xff0c;这是怎么回事&#xff1f;我用的函数是str2double和str2num都试了&#xff0c;都不行。我的程序如下,其中的E2(i)的值我用matlab计算了&#xff0c;为什么是这么庞大的一个数&#xff1f;这…

(转)在Windows上安装GPU版Tensorflow

转载自在Windows上安装GPU版Tensorflow。 1. 下载安装Anaconda 简单说就是下载 64位 python 3.5 版本的Anaconda https://www.continuum.io/downloads#windows 安装情况&#xff1a;新机&#xff0c;未装python。 注意 a. Windows只支持64位 python 3.5 https://www.ten…