前言:之前学过统计学习这门课,基本上是了解过主流的机器学习算法。但是一直没有自己从程序的角度去深入理解它们。现在准备阅读相关算法的实现源码来进一步理解这些算法。
参考资料:python《机器学习实战》
C++ Shark开源库源码
一.KNN算法原理
KNN算法可以视为是最简单的分类算法。它是一种Lazy learning,并不需要训练出来实际的数学模型,甚至也可以认为这种算法不需要训练的过程。假设我们的训练集里面有x1,x2,….,xn一共n个m维的训练样本,每个样本都有对应的标签y1,y2…,yn。现在给定测试向量t,t也是m维的向量,我们需要做的就是判断t的类别。
KNN算法首先定义一种距离度量标准来度量t和xi的远近程度,最简单的度量标准就是欧氏距离。接下来,需要找出训练集中与t距离最近的k个训练样本,这k个样本各自所属的类别也是已知的。最后,我们选取这k个样本中所属类别最多的类别作为t的类别。KNN算法基于非常朴素的事实:如果两个样本非常相似,那么它们所属的类别也应该基本相同。
KNN算法的优点:简单,精度高
KNN算法的缺点:计算复杂度高,空间复杂度高,当训练样本数目很大的时候难以实现。
二.KNN算法的C++实现(shark库源码分析)
测试源码:
#include <Rng/GlobalRng.h>
#include <ReClaM/ArtificialDistributions.h>
#include <ReClaM/Dataset.h>
#include <ReClaM/KernelNearestNeighbor.h>
#include <ReClaM/ClassificationError.h>
#include <stdio.h>
#include <iostream>
using namespace std;
int main()
{
Rng::seed(10);//初始化随机数种子
double gamma = 0.5;
RBFKernel k(gamma); //定义RBF核,使用exp(-parameter(0) * dist2)来归一化距离
cout << endl;
cout << "*** kernel nearest neighbor classifier ***" << endl;
cout << endl;
// create the xor problem with uniformly distributed examples
unsigned int n = 3;
cout << "Generating 100 training and 10000 test examples ..." << flush;
Chessboard chess(2, 2);
Dataset dataset;
dataset.CreateFromSource(chess, 100, 10000);
const Array<double>& x = dataset.getTrainingData();
const Array<double>& y = dataset.getTrainingTarget();
cout << " done." << endl;
// create the kernel mean classifier
cout << "Creating the 3-nearest-neighbor classifier ..." << flush;
KernelNearestNeighbor knn(x, y, &k, n);
cout << " done." << endl;
// estimate the accuracy on the test set
cout << "Testing ..." << flush;
ClassificationError ce;
double acc = 1.0 - ce.error(knn, dataset.getTestData(), dataset.getTestTarget());//执行实际的分类
cout << " done." << endl;
cout << "Estimated accuracy: " << 100.0 * acc << "%" << endl << endl;
// lines below are for self-testing this example, please ignore
if (acc >= 0.92) exit(EXIT_SUCCESS);
else exit(EXIT_FAILURE);
}
寻找k个近邻的核心算法程序:
double KernelNearestNeighbor::classify(Array<double> pattern)
{
int i, j, m, u, c, l = training_input.dim(0);
double dist2, best;
double norm2 = kernel->eval(pattern, pattern);
std::vector<int> used; // sorted list of neighbors
for (i = 0; i < numberOfNeighbors; i++)//for循环,每次寻找一个最近邻
{
// find the nearest neighbor not already in the list
best = 1e100;
m = 0;
for (j = 0; j < i; j++)
{
u = used[j];
for (; m < u; m++)
{
dist2 = diag(m) + norm2 - 2.0 * kernel->eval(training_input[m], pattern);
if (dist2 < best)
{
best = dist2;
c = m;
}
}
m++;
}
for (; m < l; m++)
{
dist2 = diag(m) + norm2 - 2.0 * kernel->eval(training_input[m], pattern);
if (dist2 < best)
{
best = dist2;
c = m;
}
}
// insert the nearest neighbor into the sorted list
for (j = 0; j < i; j++) if (used[j] >= c) break;
if (j == i) used.push_back(c);
else used.insert(used.begin() + j, c);
}
double mean = 0.0;
for (i = 0; i < numberOfNeighbors; i++) mean += training_target(used[i], 0);
return (mean > 0.0) ? 1.0 : -1.0;
}
shark库实现的寻找K近邻算法的复杂度是O(k*n),每次需要遍历整个训练集,第一次寻找距离最小的样本,第二次寻找距离第二小的样本,直至寻找出距离第K小的样本。想起之前写的博客《【July程序员编程艺术】之最小的k个数问题》,可以采用最大堆的数据结构,那样可以把复杂度优化到O(n*log k)。
三.KNN算法的python实现(机器学习实战源码)
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0] //得到数据集大小
diffMat = tile(inX, (dataSetSize,1)) - dataSet//将输入向量扩展成矩阵,然后减去训练矩阵,得到差值矩阵
sqDiffMat = diffMat**2//差值矩阵每个元素求平方
sqDistances = sqDiffMat.sum(axis=1)//每一行求和,即求解每个训练样本与测试向量的距离
distances = sqDistances**0.5
sortedDistIndicies = distances.argsort() //排序
classCount={}
for i in range(k)://进行投票
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)//对投票结果排序
return sortedClassCount[0][0]
python实现使用了numpy库,因此可以直接使用矩阵进行运算,相比于C++的实现要方便很多。python实现中还对训练样本的特征向量值做了一下归一化,这也是很有意义的操作:
def autoNorm(dataSet):
minVals = dataSet.min(0)
maxVals = dataSet.max(0)
ranges = maxVals - minVals
normDataSet = zeros(shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet - tile(minVals, (m,1))
normDataSet = normDataSet/tile(ranges, (m,1)) #element wise divide
return normDataSet, ranges, minVals