import euclideanDistance from "./euclidean_distance.js";
import makeMatrix from "./make_matrix.js";
import max from "./max.js";
/**
 * Calculate the [silhouette values](https://en.wikipedia.org/wiki/Silhouette_(clustering))
 * for clustered data.
 *
 * For a point `i` whose group has more than one member the value is
 * `(b - a) / max(a, b)`, where `a` is the mean distance from `i` to the
 * *other* members of its own group and `b` is the mean distance from `i` to
 * the members of the nearest other group. A point that is alone in its
 * group gets a silhouette of 0.
 *
 * @param {Array<Array<number>>} points N-dimensional coordinates of the points.
 * @param {Array<number>} labels Labels of the points. Must have the same
 * length as `points` and take values in [0..G-1], where G is the number of
 * groups.
 * @return {Array<number>} The silhouette value of each point, in input order.
 * @throws {Error} If `points` and `labels` differ in length.
 *
 * @example
 * silhouette([[0.25], [0.75]], [0, 0]); // => [1.0, 1.0]
 */
function silhouette(points, labels) {
    if (points.length !== labels.length) {
        throw new Error("标签数量必须与数据点数量严格一致");
    }
    const groupings = createGroups(labels);
    const distances = calculateAllDistances(points);
    const result = [];
    for (let i = 0; i < points.length; i++) {
        const ownGroup = groupings[labels[i]];
        let s = 0;
        if (ownGroup.length > 1) {
            // Cohesion: average over the other members only. The point's
            // distance to itself must not be counted, hence the explicit
            // skip and the (size - 1) divisor.
            let withinTotal = 0;
            for (let j = 0; j < ownGroup.length; j++) {
                if (ownGroup[j] !== i) {
                    withinTotal += distances[i][ownGroup[j]];
                }
            }
            const a = withinTotal / (ownGroup.length - 1);
            const b = meanDistanceToNearestGroup(
                i,
                labels,
                groupings,
                distances
            );
            const scale = Math.max(a, b);
            // a === b === 0 happens when the point coincides with every
            // member of both groups; report 0 instead of 0 / 0 = NaN.
            s = scale === 0 ? 0 : (b - a) / scale;
        }
        result.push(s);
    }
    return result;
}
/**
 * Create a lookup table mapping group IDs to point IDs.
 *
 * Every group ID in [0..G-1] gets an array, so a label value that no point
 * uses yields an empty array rather than a hole in the result.
 *
 * @private
 * @param {Array<number>} labels Labels of the points. Must have the same
 * length as the points array and take values in [0..G-1], where G is the
 * number of groups.
 * @return {Array<Array<number>>} An array of length G; element `g` lists the
 * indices of the points labelled `g`, in ascending order.
 */
function createGroups(labels) {
    const numGroups = 1 + max(labels);
    const result = Array(numGroups);
    // Pre-fill so that unused labels produce [] instead of undefined.
    for (let g = 0; g < numGroups; g++) {
        result[g] = [];
    }
    for (let i = 0; i < labels.length; i++) {
        const label = labels[i];
        // Only reachable for labels outside [0..G-1]; kept so such input
        // behaves as it did before instead of throwing here.
        if (result[label] === undefined) {
            result[label] = [];
        }
        result[label].push(i);
    }
    return result;
}
/**
 * Build a lookup table of the distances between every pair of points.
 *
 * Only the off-diagonal cells are written here; the diagonal keeps whatever
 * value `makeMatrix` initialises it with (presumably 0 - confirm against
 * `makeMatrix`).
 *
 * @private
 * @param {Array<Array<number>>} points N-dimensional coordinates of the points.
 * @return {Array<Array<number>>} A symmetric square matrix in which cell
 * `[i][j]` holds the Euclidean distance between point `i` and point `j`.
 */
function calculateAllDistances(points) {
    const count = points.length;
    const table = makeMatrix(count, count);
    // Row 0 has no columns below the diagonal, so start at row 1.
    for (let row = 1; row < count; row++) {
        for (let col = 0; col < row; col++) {
            const d = euclideanDistance(points[row], points[col]);
            // Mirror the value across the diagonal to keep the table symmetric.
            table[row][col] = d;
            table[col][row] = d;
        }
    }
    return table;
}
/**
 * Calculate the mean distance between a point and the members of the
 * nearest other group, where "nearest" means the group with the smallest
 * mean distance to the point.
 *
 * Groups with no members (missing or empty entries in `groupings`) are
 * ignored. If there is no other non-empty group, `Number.MAX_VALUE` is
 * returned.
 *
 * @private
 * @param {number} which The index of this point.
 * @param {Array<number>} labels Labels of the points.
 * @param {Array<Array<number>>} groupings An array whose element `g` lists
 * the indices of the points in group `g`.
 * @param {Array<Array<number>>} distances A symmetric square matrix of
 * inter-point distances.
 * @return {number} The smallest mean distance from this point to any other
 * group.
 */
function meanDistanceToNearestGroup(which, labels, groupings, distances) {
    const label = labels[which];
    let result = Number.MAX_VALUE;
    for (let i = 0; i < groupings.length; i++) {
        const group = groupings[i];
        // Skip the point's own group, and any label value that has no
        // members: there is nothing to average over and dereferencing a
        // missing entry would throw.
        if (i === label || !group || group.length === 0) {
            continue;
        }
        const d = meanDistanceFromPointToGroup(which, group, distances);
        if (d < result) {
            result = d;
        }
    }
    return result;
}
/**
 * Calculate the mean distance between a point and all the points in a group
 * (which may be the point's own group).
 *
 * The sum is divided by the full size of the group, so an empty group
 * yields NaN (0 / 0).
 *
 * @private
 * @param {number} which The index of this point.
 * @param {Array<number>} group The indices of the points in the target group.
 * @param {Array<Array<number>>} distances A symmetric square matrix of
 * inter-point distances.
 * @return {number} The mean distance from this point to the group's members.
 */
function meanDistanceFromPointToGroup(which, group, distances) {
    let sum = 0;
    for (const member of group) {
        sum += distances[which][member];
    }
    return sum / group.length;
}
export default silhouette;