Spark数据挖掘-数据标准化

1 前言

特征数据标准化指的是对训练样本通过利用每一列的统计量将特征列转换为0均值单位方差的数据。 这是非常通用的数据预处理步骤。
例如:RBF核的支持向量机或者基于L1和L2正则化的线性模型在数据标准化之后效果会更好。
数据标准化能够改进优化过程中数据收敛的速度,也能防止一些方差过大的变量特征对模型训练 产生过大的影响。
如何对数据标准化呢?公式也非常简单:新的列 = (老的列每一个值 - 老的列平均值) / (老的列标准差)

2 数据准备

在标准化之前,Spark必须知道每一列的平均值,方差,具体怎么知道呢?
想法很简单,首先给 Spark的 StandardScaler 一批数据,这批数据以 org.apache.spark.mllib.feature.Vector 的形式提供给 StandardScaler。StandardScaler 对输入的数据进行 fit 即计算每一列的平均值,方差。 调度代码如下:

import org.apache.spark.SparkContext._
import org.apache.spark.mllib.feature.StandardScaler
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLUtils

val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

val scaler1 = new StandardScaler().fit(data.map(x => x.features))
val scaler2 = new StandardScaler(withMean = true, withStd = true).fit(data.map(x => x.features))

上面代码的本质是生成一个包含每一列均值和方差的 StandardScalarModel,具体解释一下 withMean 和 withStd 的含义:

  • withMean 如果值为true,那么将会对列中每个元素减去均值(否则不会减)
  • withStd 如果值为true,那么将会对列中每个元素除以标准差(否则不会除,这个值一般为 true,否则没有标准化没有意义) 所以上面两个参数都为 false 是没有意义的,模型什么都不会干,返回原来的值,这些将会在下面的代码中得到验证。

下面给出上面 fit 函数的源代码:

/**
  * 计算数据每一列的平均值标准差,将会用于之后的标准化.
  *
  * @param data The data used to compute the mean and variance to build the transformation model.
  * @return a StandardScalarModel
  */
 @Since("1.1.0")
 def fit(data: RDD[Vector]): StandardScalerModel = {
   // TODO: 如果 withMean 和 withStd 都为false,什么都不用干
   //计算基本统计
   val summary = data.treeAggregate(new MultivariateOnlineSummarizer)(
     (aggregator, data) => aggregator.add(data),
     (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
   //通过标准差,平均值得到模型
   new StandardScalerModel(
     Vectors.dense(summary.variance.toArray.map(v => math.sqrt(v))),
     summary.mean,
     withStd,
     withMean)
 }

从这里可以发现,如果你知道每一列的平均值和方差,直接通过 StandardScalarModel 构建模型就可以了,如下代码:

val scaler3 = new StandardScalerModel(scaler2.std, scaler2.mean)

3 数据标准化

准备工作做好了,下面真正标准化,调用代码也非常简单:

al data1 = data.map(x => (x.label, scaler1.transform(x.features)))

用模型对每一行 transform 就可以了,背后的原理也非常简单,代码如下:

// 因为 `shift` 只是在 `withMean` 为真的分支中才使用, 所以使用了
 // `lazy val`. 注意:这里不想在每一次 `transform` 都计算一遍 shift.
 private lazy val shift: Array[Double] = mean.toArray

 /**
  * Applies standardization transformation on a vector.
  *
  * @param vector Vector to be standardized.
  * @return Standardized vector. If the std of a column is zero, it will return default `0.0`
  *         for the column with zero std.
  */
 @Since("1.1.0")
 override def transform(vector: Vector): Vector = {
   require(mean.size == vector.size)
   if (withMean) {
     // By default, Scala generates Java methods for member variables. So every time when
     // the member variables are accessed, `invokespecial` will be called which is expensive.
     // This can be avoid by having a local reference of `shift`.
     val localShift = shift
     vector match {
       case DenseVector(vs) =>
         val values = vs.clone()
         val size = values.size
         if (withStd) {
           var i = 0
           while (i < size) {
             values(i) = if (std(i) != 0.0) (values(i) - localShift(i)) * (1.0 / std(i)) else 0.0
             i += 1
           }
         } else {
           var i = 0
           while (i < size) {
             values(i) -= localShift(i)
             i += 1
           }
         }
         Vectors.dense(values)
       case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
     }
   } else if (withStd) {
     vector match {
       case DenseVector(vs) =>
         val values = vs.clone()
         val size = values.size
         var i = 0
         while(i < size) {
           values(i) *= (if (std(i) != 0.0) 1.0 / std(i) else 0.0)
           i += 1
         }
         Vectors.dense(values)
       case SparseVector(size, indices, vs) =>
         // For sparse vector, the `index` array inside sparse vector object will not be changed,
         // so we can re-use it to save memory.
         val values = vs.clone()
         val nnz = values.size
         var i = 0
         while (i < nnz) {
           values(i) *= (if (std(indices(i)) != 0.0) 1.0 / std(indices(i)) else 0.0)
           i += 1
         }
         Vectors.sparse(size, indices, values)
       case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
     }
   } else {
     // Note that it's safe since we always assume that the data in RDD should be immutable.
     vector
   }
 }

标准化原理简单,代码也简单,但是作用不能小看。