C语言如何实现KNN算法?

99ANYc3cd6
预计阅读时长 29 分钟
位置: 首页 C语言 正文
  1. KNN 算法核心思想
  2. 实现 KNN 的 C 语言步骤
  3. 完整的 C 语言代码示例
  4. 代码解析与关键点
  5. 如何扩展与优化

KNN 算法核心思想

KNN 的工作原理非常直观,可以概括为“物以类聚,人以群分”。

c c 语言 knn 算法 c语言
(图片来源网络,侵删)

对于一个给定的未知数据点,要判断它的类别,KNN 算法会执行以下步骤:

  1. 计算距离:计算这个未知数据点与训练集中所有已知类别数据点的距离,常用的距离度量方法是欧几里得距离
  2. 寻找邻居:根据计算出的距离,对所有训练数据点进行排序,然后选择距离最近的 k 个数据点,这 k 个点就是未知点的“k 个最近邻”。
  3. 投票决定:统计这 k 个最近邻中,各个类别出现的次数。
  4. 确定类别:将未知数据点归为出现次数最多的那个类别。

关键参数

  • k:选择邻居的数量,k 值的选择对结果影响很大,k 值太小容易受噪声影响,k 值太大可能导致分类边界模糊。

实现 KNN 的 C 语言步骤

在 C 语言中,我们需要手动管理内存和数据结构,以下是实现 KNN 的具体步骤:

  1. 定义数据结构

    c c 语言 knn 算法 c语言
    (图片来源网络,侵删)
    • 需要一个结构体来表示一个数据点,包含其特征值和标签(类别)。
    • 需要一个结构体来存储一个邻居的信息,包括数据点的索引、与目标点的距离和它的标签。
  2. 读取数据

    从文件或硬编码中加载训练数据,为了方便,我们通常将数据存储在一个文本文件中,每行代表一个数据点,最后一列是标签,前面是特征值。

  3. 计算距离

    • 编写一个函数,计算两个数据点之间的欧几里得距离,对于 n 维特征,距离公式为: distance = sqrt((x1-y1)^2 + (x2-y2)^2 + ... + (xn-yn)^2)
  4. 寻找 k 个最近邻

    c c 语言 knn 算法 c语言
    (图片来源网络,侵删)
    • 遍历整个训练数据集,计算每个点到目标点的距离。
    • 使用一个数组来维护当前找到的 k 个最近邻,在每次计算一个新点的距离后,与数组中距离最远的点进行比较,如果更近,则替换它,这是一种高效的实现方式,无需对整个数据集进行完整排序。
  5. 投票并预测

    • 找到 k 个最近邻后,统计它们的标签。
    • 找出出现次数最多的标签,这就是预测结果。
  6. 主函数

    整合所有步骤,定义训练数据、k 值和待预测的数据点,然后调用上述函数进行预测并输出结果。


完整的 C 语言代码示例

下面是一个完整的、可运行的 C 语言 KNN 分类器示例,我们使用一个简单的二维数据集来进行演示。

数据集 (data.txt)

假设我们有以下数据,前两列是特征(如身高、体重),最后一列是类别(0 或 1)。

0 1.0 0
1.1 1.0 0
1.0 1.1 0
5.0 5.0 1
5.1 5.1 1
5.0 5.0 1

C 语言代码 (knn.c)

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>
// 定义数据点的结构体
typedef struct {
    double *features; // 特征数组,[身高, 体重, ...]
    int label;        // 标签 (类别)
} DataPoint;
// 定义邻居的结构体,用于存储距离和索引
typedef struct {
    int index;        // 在训练集中的索引
    double distance;  // 到目标点的距离
    int label;        // 该邻居的标签
} Neighbor;
// 函数声明
double calculate_distance(DataPoint a, DataPoint b);
void find_k_nearest_neighbors(DataPoint *train_data, int train_size, DataPoint target, int k, Neighbor *neighbors);
int predict_label(Neighbor *neighbors, int k);
int main() {
    // --- 1. 准备数据 ---
    // 简单起见,我们直接在代码中硬编码数据
    // 在实际应用中,你可能会从文件中读取这些数据
    int num_features = 2; // 每个数据点有2个特征
    int num_train_samples = 6;
    DataPoint train_data[6];
    // 分配内存并初始化训练数据
    for (int i = 0; i < num_train_samples; i++) {
        train_data[i].features = (double *)malloc(num_features * sizeof(double));
    }
    // 填充训练数据 (来自 data.txt)
    train_data[0].features[0] = 1.0; train_data[0].features[1] = 1.0; train_data[0].label = 0;
    train_data[1].features[0] = 1.1; train_data[1].features[1] = 1.0; train_data[1].label = 0;
    train_data[2].features[0] = 1.0; train_data[2].features[1] = 1.1; train_data[2].label = 0;
    train_data[3].features[0] = 5.0; train_data[3].features[1] = 5.0; train_data[3].label = 1;
    train_data[4].features[0] = 5.1; train_data[4].features[1] = 5.1; train_data[4].label = 1;
    train_data[5].features[0] = 5.0; train_data[5].features[1] = 5.0; train_data[5].label = 1;
    // --- 2. 定义待预测的数据点 ---
    DataPoint target_point;
    target_point.features = (double *)malloc(num_features * sizeof(double));
    target_point.features[0] = 0.9; // 特征1
    target_point.features[1] = 1.2; // 特征2
    // 目标点的标签是未知的,需要我们去预测
    // --- 3. 设置 K 值 ---
    int k = 3;
    // --- 4. 执行 KNN 算法 ---
    // 创建一个数组来存储 k 个最近邻
    Neighbor k_neighbors[k];
    // 寻找 k 个最近邻
    find_k_nearest_neighbors(train_data, num_train_samples, target_point, k, k_neighbors);
    // 预测标签
    int predicted_label = predict_label(k_neighbors, k);
    // --- 5. 输出结果 ---
    printf("目标点: (%.2f, %.2f)\n", target_point.features[0], target_point.features[1]);
    printf("K 值: %d\n", k);
    printf("K 个最近邻及其距离和标签:\n");
    for (int i = 0; i < k; i++) {
        printf("  邻居 %d: 索引=%d, 距离=%.4f, 标签=%d\n",
               i + 1, k_neighbors[i].index, k_neighbors[i].distance, k_neighbors[i].label);
    }
    printf("\n预测结果: 目标点的类别是 %d\n", predicted_label);
    // --- 6. 释放内存 ---
    for (int i = 0; i < num_train_samples; i++) {
        free(train_data[i].features);
    }
    free(target_point.features);
    return 0;
}
/**
 * @brief 计算两个数据点之间的欧几里得距离
 */
double calculate_distance(DataPoint a, DataPoint b) {
    double sum = 0.0;
    for (int i = 0; i < 2; i++) { // 假设是二维特征
        sum += pow(a.features[i] - b.features[i], 2);
    }
    return sqrt(sum);
}
/**
 * @brief 寻找 k 个最近邻
 * @param train_data 训练数据集
 * @param train_size 训练数据集大小
 * @param target 目标数据点
 * @param k k值
 * @param neighbors 用于存储结果的 k 个邻居数组
 */
void find_k_nearest_neighbors(DataPoint *train_data, int train_size, DataPoint target, int k, Neighbor *neighbors) {
    // 1. 初始化邻居数组,用第一个 k 个点填充
    for (int i = 0; i < k; i++) {
        neighbors[i].index = i;
        neighbors[i].distance = calculate_distance(target, train_data[i]);
        neighbors[i].label = train_data[i].label;
    }
    // 2. 对初始的 k 个邻居按距离从大到小排序 (冒泡排序,简单实现)
    for (int i = 0; i < k - 1; i++) {
        for (int j = 0; j < k - i - 1; j++) {
            if (neighbors[j].distance < neighbors[j + 1].distance) {
                Neighbor temp = neighbors[j];
                neighbors[j] = neighbors[j + 1];
                neighbors[j + 1] = temp;
            }
        }
    }
    // 3. 遍历剩余的训练数据
    for (int i = k; i < train_size; i++) {
        double current_distance = calculate_distance(target, train_data[i]);
        // 如果当前点比邻居数组中距离最远的点(即第一个点)更近
        if (current_distance < neighbors[0].distance) {
            // 用当前点替换最远的邻居
            neighbors[0].index = i;
            neighbors[0].distance = current_distance;
            neighbors[0].label = train_data[i].label;
            // 重新对这 k 个邻居进行排序,确保第一个始终是最大的
            for (int j = 0; j < k - 1; j++) {
                if (neighbors[j].distance < neighbors[j + 1].distance) {
                    Neighbor temp = neighbors[j];
                    neighbors[j] = neighbors[j + 1];
                    neighbors[j + 1] = temp;
                }
            }
        }
    }
}
/**
 * @brief 对 k 个最近邻进行投票,预测标签
 */
int predict_label(Neighbor *neighbors, int k) {
    int count0 = 0, count1 = 0;
    for (int i = 0; i < k; i++) {
        if (neighbors[i].label == 0) {
            count0++;
        } else {
            count1++;
        }
    }
    // 返回票数最多的标签
    return (count0 > count1) ? 0 : 1;
}

代码解析与关键点

  1. DataPointNeighbor 结构体

    • DataPoint 封装了机器学习中的样本,包括它的特征和它所属的类别。
    • Neighbor 是一个辅助结构体,在寻找最近邻的过程中,我们不仅需要知道邻居是谁,还需要知道它的距离和标签,这样最后投票时才方便。
  2. calculate_distance 函数

    • 核心是欧几里得距离公式,这里我们假设特征是二维的(for (int i = 0; i < 2; i++)),如果要处理更高维度的数据,只需要修改这个循环的上限即可。
  3. find_k_nearest_neighbors 函数 (核心逻辑)

    • 初始化:我们用一个简单的循环,将前 k 个训练点作为初始的“候选邻居”。
    • 排序:我们对这 k 个候选邻居进行降序排序(从大到小),这样,neighbors[0] 始终是这 k 个点中距离最远的,这一步是为了方便后续的比较和替换。
    • 遍历与替换:我们从第 k+1 个训练点开始,计算它到目标点的距离。
      • 如果这个新点的距离小于 neighbors[0] 的距离(即比当前候选邻居中最远的那个还要近),那么我们就用这个新点替换掉 neighbors[0]
      • 替换后,neighbors[0] 又成了新的“最远点”,所以我们需要重新对这 k 个点进行排序,以确保 neighbors[0] 仍然是最大的。
    • 这个方法避免了每次都要对整个数据集进行排序,时间复杂度从 O(N log N) 降低到了 O(N*k),对于大数据集和较小的 k 值来说更高效。
  4. predict_label 函数

    • 逻辑非常简单,就是一个“多数表决”,遍历 k 个邻居,统计每个类别的票数,返回票数最多的类别。
  5. 内存管理

    • C 语言中,动态分配的内存在使用完毕后必须手动释放,我们在 main 函数开头使用 malloc 为每个数据点的 features 数组分配了内存,在函数结尾使用 free 将其释放,这是一个非常重要的好习惯,可以防止内存泄漏。

如何扩展与优化

这个示例是 KNN 的一个基础实现,你可以从以下几个方面进行扩展和优化:

  1. 从文件读取数据

    • 使用 fopen, fscanf, fgets 等函数从 data.txt 这样的文件中读取数据,而不是硬编码,这会使你的程序更加灵活和实用。
  2. 支持任意维度特征

    修改代码,使其能够处理任意数量的特征,你可以在读取数据时确定特征的维度,或者通过命令行参数传入。

  3. 更高效的排序算法

    • find_k_nearest_neighbors 函数中,我们使用了简单的冒泡排序来维护 k 个邻居,当 k 值较大时,可以考虑使用更高效的排序算法,如快速排序或堆排序,或者使用 C++ 标准库中的 std::sort(如果你用 C++ 实现)。
  4. 归一化/标准化

    • KNN 对数据的尺度非常敏感,如果一个特征的数值范围远大于另一个特征(年龄在 0-100,而收入在 10000-1000000),那么距离计算将主要由收入这个特征主导,为了避免这个问题,在使用 KNN 之前,通常需要对数据进行归一化(如 Min-Max Scaling)或标准化(如 Z-score Normalization)。
  5. 使用更高效的数据结构

    • 对于非常大的数据集,线性搜索(遍历所有点)会非常慢,可以考虑使用KD-TreeBall Tree 等空间划分数据结构来加速近邻搜索,这些数据结构可以将搜索的复杂度从 O(N) 降低到 O(log N)。

希望这个详细的教程能帮助你理解并掌握如何在 C 语言中实现 KNN 算法!

-- 展开阅读全文 --
头像
PHP模板如何安装到织梦根目录?
« 上一篇 2025-12-11
织梦CMS如何让已阅读文章显示已阅读?
下一篇 » 2025-12-11
取消
微信二维码
支付宝二维码

目录[+]