一、SVM支持向量机武林故事
我最早接触SVM支持向量机的时候,是看到一篇博客,说的是武林的故事,但是现在我怎么也找不到了,凭借着印象,重述一下这段传说:
- 相传故事发生在古时候,咸亨酒店,热闹非凡
- 店长出了一道题,完成挑战的人可以迎娶小姐
- 只见,桌子上放着黑棋和白棋,挑战者需要寻找一条线将黑棋和白旗完全隔开
- 刚刚开始,棋子比较少,大侠轻松的完成了任务
- 但随着棋子的增多,大侠百思不得其解,最后怒拍桌子
- 黑棋和白棋都飞到了空中,大侠一个剑气,空气凝固了,黑棋和白棋正好被隔开了
二、SVM支持向量机算法
支持向量机SVM是分类器,目标是寻找一个“超平面”,将样本合理的分类。简单来说,就是找一个函数,最优切割所有样本数据,所以“超平面”可以定义为:
解这个拉格朗日最优化问题,就是支持向量机分类的原理 ~
三、SVM支持向量机实战
- 下载地址:github机器学习数据下载
- 第一列是标签,后面是特征
支持向量机SVM,长剑一挥,剑气形成了一道屏障,隔开了好人(0)和坏人(1),我训练的准确率有93%,还可以哟~
import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkContext._
import org.apache.spark.SparkContext
import org.apache.spark.SparkConf
import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, NaiveBayes, NaiveBayesModel, SVMWithSGD}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.util.{KMeansDataGenerator, LinearDataGenerator, LogisticRegressionDataGenerator, MLUtils}
import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD}
//向量
import org.apache.spark.mllib.linalg.Vector
//向量集
import org.apache.spark.mllib.linalg.Vectors
//稀疏向量
import org.apache.spark.mllib.linalg.SparseVector
//稠密向量
import org.apache.spark.mllib.linalg.DenseVector
//实例
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
//矩阵
import org.apache.spark.mllib.linalg.{Matrix, Matrices}
//索引矩阵
import org.apache.spark.mllib.linalg.distributed.RowMatrix
//RDD
import org.apache.spark.rdd.RDD
object WordCount {
def main(args: Array[String]) {
// 构建Spark 对象
Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
val conf = new SparkConf().setAppName("HACK-AILX10").setMaster("local")
val sc = new SparkContext(conf)
// 读取样本数据
val datapath = "C:\\study\\spark\\sample_libsvm_data.txt"
val data = MLUtils.loadLibSVMFile(sc,datapath)
val splits = data.randomSplit(Array(0.6,0.4),seed=1L)
val training = splits(0)
val testing = splits(1)
//新建支持向量机SVM,并设置训练参数
val model = SVMWithSGD.train(training,100)
//对样本进行测试
val prediction_and_label = testing.map {
p => (model.predict(p.features), p.label)
}
val print_predict = prediction_and_label.take(5)
println("预测结果" + "\t\t\t\t\t\t" + "标签")
for (i <- 0 to print_predict.length -1){
println(print_predict(i)._1 + "\t\t\t\t\t\t" + print_predict(i)._2)
}
//计算测试误差
val metrics = new MulticlassMetrics(prediction_and_label)
val accuracy = metrics.accuracy
println("准确率=" + accuracy)
// 模型保存
val modelpath = "C:\\study\\spark\\SVM"
model.save(sc,modelpath)
println("模型保存 ok")
}
}
本篇完~