对同一像素点值的像素点归为一类,通过平均值进行取代,从而将图像进行压缩并且保证图像尽可能不失真,关键信息仍保留。
from PIL import Image
import numpy as np
from sklearn.cluster import KMeans
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3Ddef restore_image(cb, cluster, shape):row, col, dummy = shapeimage = np.empty((row, col, 3))index = 0for r in range(row):for c in range(col):image[r, c] = cb[cluster[index]]index += 1return imagedef show_scatter(a):N = 10print('原始数据:\n', a)density, edges = np.histogramdd(a, bins=[N,N,N], range=[(0,1), (0,1), (0,1)])density /= density.max()x = y = z = np.arange(N)d = np.meshgrid(x, y, z)fig = plt.figure(1, facecolor='w')ax = fig.add_subplot(111, projection='3d')ax.scatter(d[1], d[0], d[2], c='r', s=100*density, marker='o', depthshade=True)ax.set_xlabel(u'红色分量')ax.set_ylabel(u'绿色分量')ax.set_zlabel(u'蓝色分量')plt.title(u'图像颜色三维频数分布', fontsize=20)plt.figure(2, facecolor='w')den = density[density > 0]den = np.sort(den)[::-1]t = np.arange(len(den))plt.plot(t, den, 'r-', t, den, 'go', lw=2)plt.title(u'图像颜色频数分布', fontsize=18)plt.grid(True)plt.show()if __name__ == '__main__':matplotlib.rcParams['font.sans-serif'] = [u'SimHei']matplotlib.rcParams['axes.unicode_minus'] = Falsenum_vq = 256 #256个像素,最后想降维到256个维度im = Image.open('Lena.png') # flower2.png(200)/lena.png(50) 读取图片image = np.array(im).astype(np.float) / 255 # 将图像数据转化为array类型,方便后续的操作image = image[:, :, :3]#所有行、列、前三个维度,因为png有四个属性,RGBα,alpha为透明度,不需要,只用前三个维度信息即可image_v = image.reshape((-1, 3))#拉伸像素点,行不关心,列为3列===转换为二维数据,每一列均为一个维度的全部数据,RGB变成了3列,每行为一个像素点model = KMeans(num_vq)#通过Kmeans对每一行进行处理,也就是每个像素点进行处理,对每个像素点进行分类,像素点相似度的归类;创建聚类对象#每类都有一个中心像素点,将该中心像素点代替这一类,这里手动传入的是分成256个类别。不管像素点位置,只考虑相似与否是否为一类show_scatter(image_v)#画图N = image_v.shape[0] # 图像像素总数# 选择足够多的样本(如1000个),计算聚类中心idx = np.random.randint(0, N, size=1000)#从图像中随机选取1000个像素点image_sample = image_v[idx]model.fit(image_sample)#将这1000个像素点去训练模型,聚类结果,从1000个像素点中找到最重要的256个像素点作为中心点c = model.predict(image_v) # 将图像全部的像素点进行预测,看看图像中的所有像素点离这256个簇哪一个最近,把图像的所有像素点进行分类print('聚类结果:\n', c)print('聚类中心:\n', model.cluster_centers_)plt.figure(figsize=(15, 8), facecolor='w')plt.subplot(121)plt.axis('off')plt.title(u'原始图片', fontsize=18)plt.imshow(image)#plt.savefig('1.png')plt.subplot(122)vq_image = restore_image(model.cluster_centers_, c, image.shape)#聚类中心点、聚类结果、模型图像的形状 作为参数进行恢复图像plt.axis('off')plt.title(u'矢量量化后图片:%d色' % num_vq, fontsize=20)plt.imshow(vq_image)#plt.savefig('2.png')plt.tight_layout()plt.show()