k-means算法又称k-均值算法,是机器学习聚类算法中的一种,是一种基于形心的划分方法,其中每个簇的中心都用簇中所有对象的均值来表示。其思想如下:
输入:
- k:簇的数目;
- D:包含n个对象的数据集。
方法:
- 从D中随机选择几个对象作为起始质心;
- 对每个质心,计算每个数据到各个质心的距离,并把这些点分配到离该质心最短的距离的簇;
- 对每个簇,计算簇中所有点的均值并将此均值作为新的质心;
- 将数据点按照新的中心重新聚类;
- 重复【步骤3】,直到质心不再发生变化(新的质心和原来的质心相等);
- 输出聚类结果。
木羊的k-means算法实现包括5各类。其中,DBConnection.java用于连接数据库,SelectData.java用于从数据库里读取数据,Point.java存放点对象模型,ManagePoint.java是对点的操作,Kmeans.java是算法的核心思想及主函数入口。以下分别给出各个类的详细代码:
DBConnection.java
数据集获取,在机器学习数据集获取官方网站UCI中点击打开链接,木羊已经把该数据集从txt文档中插入到数据库,并去除了最后一列(花类别)。读者若不熟悉数据库的读写,请百度。若木羊有时间,会在后面的博文中补充把txt文档内容读到数据库中的内容。
<span style="font-size:18px;">package db;import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;/*** * 数据库连接类* */
public class DBConnection {public static final String driver = "com.mysql.jdbc.Driver";public static final String url = "jdbc:mysql://localhost:3306/mydb";public static final String user = "root";public static final String pwd = "123";public static Connection dBConnection() {Connection con = null;try {// 加载mysql驱动器Class.forName(driver);// 建立数据库连接con = DriverManager.getConnection(url, user, pwd);} catch (ClassNotFoundException e) {// TODO Auto-generated catch blockSystem.out.println("加载驱动器失败");e.printStackTrace();} catch (SQLException e) {// TODO Auto-generated catch blockSystem.out.println("注册驱动器失败");e.printStackTrace();}return con;}
}</span>
数据库中的数据字段如下(共有150条数据):
SelectData.java
package dao;import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;import model.Point;
import db.DBConnection;/*** * 取出数据* * @return pointList* */
public class SelectData {public static final String SELECT = "select* from iris_Kmeans";public ArrayList<Point> getPoints() throws SQLException {ArrayList<Point> pointsList = new ArrayList<Point>();Connection con = DBConnection.dBConnection();ResultSet rs;// 创建一个PreparedStatement对象PreparedStatement pstmt = con.prepareStatement(SELECT);rs = pstmt.executeQuery();while (rs.next()) {Point point = new Point();point.setX(rs.getDouble(2));point.setY(rs.getDouble(3));point.setZ(rs.getDouble(4));point.setW(rs.getDouble(5));pointsList.add(point);}System.out.println("数据集: " + pointsList);pstmt.close();rs.close();con.close();return pointsList;}
}
Point.java
此处要注意重写equal和hashcode方法以便后面质心的比较。
package model;public class Point {private double x;private double y;private double z;private double w;public double getX() {return x;}public void setX(double x) {this.x = x;}public double getY() {return y;}public void setY(double y) {this.y = y;}public double getZ() {return z;}public void setZ(double z) {this.z = z;}public double getW() {return w;}public void setW(double w) {this.w = w;}public Point() {}public Point(double x, double y, double z, double w) {super();this.x = x;this.y = y;this.z = z;this.w = w;}@Overridepublic String toString() {return "Point [ x=" + x + ", y=" + y + ", z=" + z + ", w=" + w + "]";}@Overridepublic boolean equals(Object obj) {Point point = (Point) obj;if (this.getX() == point.getX() && this.getY() == point.getY()&& this.getZ() == point.getZ() && this.getW() == point.getW()) {return true;}return false;}@Overridepublic int hashCode() {return (int) (x + y + z + w);}
}
该类包含了3个方法,分别用于计算两个点的欧氏距离,比较前后两个质心是否相同,更新质心。
package util;import java.util.ArrayList;
import java.util.Map;import model.Point;public class ManagePoint {/*** * 计算两点之间的距离* * @param p* 第一个点* @param q* 第二个点* @return distance* */public double getDistance(Point p, Point q) {double dx = p.getX() - q.getX();double dy = p.getY() - q.getY();double dz = p.getZ() - q.getZ();double dw = p.getW() - q.getW();double distance = Math.sqrt(dx * dx + dy * dy + dz * dz + dw * dw);return distance;}/*** 判断前后两个质心是否相同* * @param nowCenterCluster* 现在的质心* @param lastCenterCluster* 上一次的质心* @return boolean* */public boolean isEqual(Map<Point, ArrayList<Point>> lastCenterCluster,Map<Point, ArrayList<Point>> nowCenterCluster) {boolean contain = false;if (lastCenterCluster == null)return false;else {for (Point point : nowCenterCluster.keySet()) {contain = lastCenterCluster.containsKey(point);}if (contain)return true;}return false;}/*** * 计算新的质心* * @param value* map中的值,存放簇中的所有点* @return point* */public Point getNewCenter(ArrayList<Point> value) {double sumX = 0, sumY = 0, sumZ = 0, sumW = 0;for (Point point : value) {sumX += point.getX();sumY += point.getY();sumZ += point.getZ();sumW += point.getW();}System.out.println("新的质心: (" + sumX / value.size() + "," + sumY/ value.size() + "," + sumZ / value.size() + "," + sumW/ value.size() + ")");Point point = new Point();point.setX(sumX / value.size());point.setY(sumY / value.size());point.setZ(sumZ / value.size());point.setW(sumW / value.size());return point;}
}
Kmeans.java
木羊把簇存在hashmap里,其中key存放该簇的质心,value存放该簇的所有点。特别注意的是,为了使最终聚类相对较理想,随机选择的三个初始质心应该在[0-50)、[50-100)、[100-150]三个区间内。
package util;import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;import model.Point;
import dao.SelectData;public class Kmeans {public Map<Point, ArrayList<Point>> executeKmeans(int k) {ArrayList<Point> dataList = new ArrayList<Point>();// 存放原始数据Map<Point, ArrayList<Point>> nowCenterClusterMap = new HashMap<Point, ArrayList<Point>>();// 当前质心及其簇内的点Map<Point, ArrayList<Point>> lastCenterClusterMap = null;// 上一个质心及其簇内的点try {dataList = new SelectData().getPoints();// 随机创建K个点作为起始质心Random rd = new Random();int[] initIndex = { 50, 50, 50 };int[] tempIndex = { 0, 50, 100 };System.out.println("起始质心下标: ");for (int i = 0; i < k; i++) {int index = rd.nextInt(initIndex[i]) + tempIndex[i];System.out.println("第" + (i + 1) + "个 : " + index);nowCenterClusterMap.put(dataList.get(index),new ArrayList<Point>());}// 输出起始质心System.out.println("起始质心: ");for (Point point : nowCenterClusterMap.keySet())System.out.println("key: " + point);// 将数据点point加入配到离其最近的map的value中ManagePoint managePoint = new ManagePoint();while (true) {for (Point point : dataList) {double shortestDistance = Double.MAX_VALUE;// 初始化最短距离为Double的最大值Point key = null;for (Entry<Point, ArrayList<Point>> entry : nowCenterClusterMap.entrySet()) {// 计算质心与各点间的距离double distance = managePoint.getDistance(entry.getKey(), point);if (distance < shortestDistance) {shortestDistance = distance;key = entry.getKey();}}nowCenterClusterMap.get(key).add(point);}// 如果新的质心与上次的质心相等,则退出整个循环if (managePoint.isEqual(lastCenterClusterMap,nowCenterClusterMap)) {System.out.println("相等了。");break;}// 更新质心lastCenterClusterMap = nowCenterClusterMap;nowCenterClusterMap = new HashMap<Point, ArrayList<Point>>();System.out.println("------------------------------------------------------------------");for (Entry<Point, ArrayList<Point>> entry : lastCenterClusterMap.entrySet()) {nowCenterClusterMap.put(managePoint.getNewCenter(entry.getValue()),new ArrayList<Point>());}}} catch (SQLException e) {// TODO Auto-generated catch blockSystem.out.println("数据库操作失败");e.printStackTrace();}return nowCenterClusterMap;}public static void main(String[] args) {int K = 3;// 分为三个类Map<Point, ArrayList<Point>> result = new Kmeans().executeKmeans(K);// 输出分类System.out.println("===========聚类结果: ============");for (Entry<Point, ArrayList<Point>> entry : result.entrySet()) {System.out.println("\n" + "稳定的质心: " + entry.getKey());System.out.println("该簇的大小: " + entry.getValue().size());System.out.println("簇里的点:" + entry.getValue());}}
}
以上代码均从MyEclipse上复制粘贴而来,亲测可运行,结果如下:
经测试,无论初始质心被随机选择成哪3个,最终稳定的质心都不变。
(欢迎讨论。代码尚有不完善之处,请多多指教。转载请注明出处。)