一、 基本原理
Mean Shift是一种基于密度的非参数聚类算法,不需要预先指定簇的数量,而是通过寻找数据空间中密度最大的区域来自动确定聚类中心, 适合图像分割和目标跟踪等。
算法步骤
-
初始化:对每个数据点作为起点。
-
迭代:计算当前点的邻域内所有点的加权均值,将当前点移动到该均值位置。
-
终止:当移动距离小于阈值或达到最大迭代次数时停止。
-
聚类:合并收敛到同一位置的点为一个簇。
数学描述
核函数:通常使用高斯核,衡量数据点之间的权重:
其中 ℎ 是带宽(bandwidth
)
均值漂移向量:点 x的漂移方向为:
其中 N(x) 是 x 的邻域(由带宽决定)
二、特点
优点
-
无需指定簇数:自动发现数据中的聚类结构。
-
适应任意形状:可以识别非球形分布的簇。
-
鲁棒性:对噪声和异常值不敏感。
缺点
-
计算复杂度高:每轮迭代需要计算所有点的邻域关系,时间复杂度O(n2)。
-
带宽选择敏感:
bandwidth
对结果影响大,需谨慎选择。 -
不适合高维数据:维度灾难可能导致效果下降。
三、Python 实现
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt# 生成数据
X, _ = make_blobs(n_samples=500, centers=3, cluster_std=0.8, random_state=42)# Mean Shift 聚类
ms = MeanShift(bandwidth=1.5) # bandwidth是关键参数
ms.fit(X)
labels = ms.labels_
cluster_centers = ms.cluster_centers_# 可视化
plt.scatter(X[:, 0], X[:, 1], c=labels, cmap='viridis', alpha=0.5)
plt.scatter(cluster_centers[:, 0], cluster_centers[:, 1], c='red', marker='x', s=100)
plt.title("Mean Shift Clustering")
plt.show()
参数说明
-
bandwidth
:决定邻域大小的关键参数,使用estimate_bandwidth()
辅助确定。若值太小,会导致过多小簇;若太大,会合并所有数据为单一簇。可通过以下方法估计:from sklearn.cluster import estimate_bandwidth bandwidth = estimate_bandwidth(X, quantile=0.2) # quantile影响邻域范围
-
bin_seeding
:对大规模数据,先采样再聚类,bin_seeding设置
为True
,仅用离散化的种子点加速计算
应用实例:图像分割
from skimage import io, color
from sklearn.cluster import MeanShift# 加载图片
image = io.imread("example.jpg")
image_rgb = color.rgba2rgb(image) # 转换为RGB
h, w, _ = image_rgb.shape
X = image_rgb.reshape(-1, 3) # 将像素转换为特征向量# Mean Shift聚类
ms = MeanShift(bandwidth=0.1, bin_seeding=True)
ms.fit(X)
segmented = ms.labels_.reshape(h, w) # 还原为图像尺寸# 显示分割结果
plt.imshow(segmented, cmap='tab20')
plt.axis('off')
plt.show()