import euclideanDistance from "./euclidean_distance.js";
import makeMatrix from "./make_matrix.js";
import sample from "./sample.js";
/**
 * @typedef {Object} kMeansReturn
 * @property {Array<number>} labels Cluster index assigned to each input point, in input order.
 * @property {Array<Array<number>>} centroids Final cluster-center coordinates, indexed by cluster.
 */
/**
 * Perform k-means clustering.
 *
 * Initial centers are drawn from `points` via `sample`. The algorithm then
 * alternates two steps until no centroid moves:
 * - assign every point to its nearest center;
 * - recompute every center as the mean of its assigned points.
 *
 * @param {Array<Array<number>>} points N-dimensional coordinates of the points to cluster.
 * @param {number} numCluster Number of clusters to create.
 * @param {Function} randomSource Optional random source producing uniformly distributed values in [0, 1).
 * @return {kMeansReturn} Object holding the label array and the centroid array.
 * @throws {Error} When some cluster center ends up with no associated points.
 *
 * @example
 * kMeansCluster([[0.0, 0.5], [1.0, 0.5]], 2);
 * // => {labels: [0, 1], centroids: [[0.0, 0.5], [1.0, 0.5]]}
 * // (cluster numbering depends on the random initial sample, so the
 * //  labels may come back as [1, 0] with the centroids swapped)
 */
function kMeansCluster(points, numCluster, randomSource = Math.random) {
    let centroids = sample(points, numCluster, randomSource);
    let labels;
    let change;

    // Repeat assignment + update until an iteration leaves every centroid
    // exactly where it was. Once the labels repeat, the recomputed means are
    // produced by the same arithmetic in the same order. They are therefore
    // bit-identical, so the exact `!== 0` comparison is reachable.
    do {
        labels = labelPoints(points, centroids);
        const previous = centroids;
        centroids = calculateCentroids(points, labels, numCluster);
        change = calculateChange(centroids, previous);
    } while (change !== 0);

    return { labels, centroids };
}
/**
 * Assign every point to its nearest centroid, as measured by `euclideanDistance`.
 *
 * Ties go to the lowest centroid index, because a later centroid replaces the
 * current best only when it is strictly closer. A point keeps the label -1 if
 * no centroid is closer than `Number.MAX_VALUE` (for example when `centroids`
 * is empty).
 *
 * @private
 * @param {Array<Array<number>>} points Coordinates of the points to label.
 * @param {Array<Array<number>>} centroids Current cluster-center coordinates.
 * @return {Array<number>} For each point, the index of its nearest centroid.
 */
function labelPoints(points, centroids) {
    const nearestCentroid = (point) => {
        let bestIndex = -1;
        let bestDistance = Number.MAX_VALUE;
        for (const [index, centroid] of centroids.entries()) {
            const distance = euclideanDistance(point, centroid);
            if (distance < bestDistance) {
                bestIndex = index;
                bestDistance = distance;
            }
        }
        return bestIndex;
    };
    return points.map((point) => nearestCentroid(point));
}
/**
 * Recompute each cluster center as the arithmetic mean of its member points.
 *
 * @private
 * @param {Array<Array<number>>} points Coordinates of the data points; the dimension is taken from `points[0]`.
 * @param {Array<number>} labels Cluster index of each point, aligned with `points`.
 * @param {number} numCluster Total number of clusters.
 * @return {Array<Array<number>>} New centroid coordinates, one row per cluster.
 * @throws {Error} When some cluster has no points assigned to it (reported for the lowest such index).
 */
function calculateCentroids(points, labels, numCluster) {
    const width = points[0].length;
    const sums = makeMatrix(numCluster, width);
    const sizes = new Array(numCluster).fill(0);

    // Add each point into the running coordinate sum of its assigned cluster.
    for (const [index, point] of points.entries()) {
        const cluster = labels[index];
        const sum = sums[cluster];
        for (let axis = 0; axis < width; axis += 1) {
            sum[axis] += point[axis];
        }
        sizes[cluster] += 1;
    }

    // A cluster with no members has no defined mean.
    const emptyCluster = sizes.indexOf(0);
    if (emptyCluster !== -1) {
        throw new Error(`聚类中心${emptyCluster}无关联数据点`);
    }

    // Turn every running sum into a mean, in place.
    sums.forEach((sum, cluster) => {
        for (let axis = 0; axis < width; axis += 1) {
            sum[axis] /= sizes[cluster];
        }
    });
    return sums;
}
/**
 * Measure how far the centroids moved between two iterations.
 *
 * @private
 * @param {Array<Array<number>>} left New centroid coordinates.
 * @param {Array<Array<number>>} right Previous centroid coordinates, aligned by index with `left`.
 * @return {number} Sum of the Euclidean distances between corresponding centroids.
 */
function calculateChange(left, right) {
    let moved = 0;
    for (const [index, centroid] of left.entries()) {
        moved += euclideanDistance(centroid, right[index]);
    }
    return moved;
}
export default kMeansCluster;