perceptron.js

/**
 * 单层[感知器分类器](http://en.wikipedia.org/wiki/Perceptron),用于对数值数组进行二分类(0/1)
 * @class
 * @example
 * // 创建分类器实例
 * var p = new PerceptronModel();
 * // 使用对角线决策边界数据进行训练
 * for (var i = 0; i < 5; i++) {
 *     p.train([1, 1], 1);
 *     p.train([0, 1], 0);
 *     p.train([1, 0], 0);
 *     p.train([0, 0], 0);
 * }
 * p.predict([0, 0]); // 0
 * p.predict([0, 1]); // 0
 * p.predict([1, 0]); // 0
 * p.predict([1, 1]); // 1
 */
class PerceptronModel {
    /*:: bias: number */
    /*:: weights: Array<number> */
    constructor() {
        // 模型权重系数数组(训练后生成)
        this.weights = [];
        // 偏置项(始终与1相乘的权重)
        this.bias = 0;
    }

    /**
     * **预测**:使用特征向量与权重数组的线性组合进行二分类预测
     *
     * @param {Array<number>} features 数值型特征数组
     * @returns {number} 当加权和超过0时返回1,否则返回0
     */
    predict(features) {
        // 仅当输入特征维度与模型参数匹配时执行预测
        if (features.length !== this.weights.length) {
            return null;
        }

        // 计算特征与权重的点积并加上偏置项
        let score = 0;
        for (let i = 0; i < this.weights.length; i++) {
            score += this.weights[i] * features[i];
        }
        score += this.bias;

        // 应用阶跃函数进行二分类
        return score > 0 ? 1 : 0;
    }

    /**
     * **训练**:使用特征-标签对更新模型参数(在线学习算法)
     *
     * @param {Array<number>} features 数值型特征数组
     * @param {number} label 二分类标签(0或1)
     * @returns {PerceptronModel} 返回当前对象以支持链式调用
     */
    train(features, label) {
        // 标签有效性验证
        if (label !== 0 && label !== 1) {
            throw new Error("无效标签:仅接受0或1的二分类标签");
        }

        // 特征维度变化时重新初始化模型参数
        if (features.length !== this.weights.length) {
            this.weights = [...features]; // 使用展开运算符进行数组复制
            this.bias = 1;
        }

        // 执行预测并计算误差
        const prediction = this.predict(features);
        if (prediction !== null && prediction !== label) {
            // 应用梯度下降算法更新参数
            const error = label - prediction;
            for (let i = 0; i < this.weights.length; i++) {
                this.weights[i] += error * features[i]; // 特征权重更新
            }
            this.bias += error; // 偏置项更新
        }
        return this;
    }
}

export default PerceptronModel;