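SQL Server supports the FOR XML PATH() syntax, which wraps a query result in the specified XML tags. Our project team had implemented this feature on Spark 2.0. When we migrated to 2.3, upstream Spark had changed so much that serious compatibility problems appeared. My job was to get the function working again. As a newcomer I found the process painful, but the problem was solved in the end.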

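1. Syntax

In SQL Server, FOR XML PATH roughly means: wrap the selected columns and the concatenated string in XML tags and return the result, with optional node names. Our team's Spark implementation is called group_xmlpath(); it does not yet support specifying tag names.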
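To make the intended output concrete, here is a minimal plain-Scala sketch (no Spark involved) of what such a group-level XML aggregation produces. The object `XmlPathSketch`, the function `groupXmlPath`, and the sample data are hypothetical names used only for illustration; the default `root` and `row` tags mirror the defaults in the implementation shown later in this post, and values are not XML-escaped.

```scala
object XmlPathSketch {
  /** Builds an (open, close) tag pair; an empty name means "no tag". */
  private def tags(name: String): (String, String) =
    if (name.nonEmpty) (s"<$name>", s"</$name>") else ("", "")

  /** Wraps every row of a group in XML tags and concatenates the rows. */
  def groupXmlPath(
      rows: Seq[Seq[Any]],
      columnNames: Seq[String],
      root: String = "root",
      row: String = "row"): String = {
    val (rootOpen, rootClose) = tags(root)
    val (rowOpen, rowClose) = tags(row)
    val body = rows.map { values =>
      // Pair each value with its column name and wrap it in that tag.
      val cells = values.zip(columnNames).map { case (v, name) =>
        val (open, close) = tags(name)
        s"$open$v$close"
      }
      cells.mkString(rowOpen, "", rowClose)
    }
    body.mkString(rootOpen, "", rootClose)
  }

  def main(args: Array[String]): Unit = {
    val rows: Seq[Seq[Any]] = Seq(Seq(1, "alice"), Seq(2, "bob"))
    println(groupXmlPath(rows, Seq("id", "name")))
  }
}
```

With the sample rows above, the sketch produces one `<row>` element per input row, each containing an `<id>` and a `<name>` element, all nested inside a single `<root>` element.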

@ExpressionDescription(
  usage = "_FUNC_(expr) - Concat a list of elements in a group.")

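2. Function implementation

Under the hood this function is an aggregate, so it lives in Collect.scala and is modeled on collect_list. Spark 2.3 reworked the upper-level interface around TypedImperativeAggregate and declared many methods final, so none of our earlier custom code could be used as it was.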


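To make a custom implementation possible, we removed the final modifier from those methods so that they can be overridden.

TypedImperativeAggregate defines the aggregation workflow, which has roughly three stages: initialization, processing, and producing the result. The corresponding methods are initialize, update and merge, and eval.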

/**
* General work flow:
*
* Stage 1: initialize aggregate buffer object.
*
*   1. The framework calls `initialize(buffer: MutableRow)` to set up the empty aggregate buffer.
*   2. In `initialize`, we call `createAggregationBuffer(): T` to get the initial buffer object,
* and set it to the global buffer row.
*
*
* Stage 2: process input rows.
*
* If the aggregate mode is `Partial` or `Complete`:
*     1. The framework calls `update(buffer: MutableRow, input: InternalRow)` to process the input
* row.
*     2. In `update`, we get the buffer object from the global buffer row and call
* `update(buffer: T, input: InternalRow): Unit`.
*
* If the aggregate mode is `PartialMerge` or `Final`:
*     1. The framework call `merge(buffer: MutableRow, inputBuffer: InternalRow)` to process the
* input row, which are serialized buffer objects shuffled from other nodes.
*     2. In `merge`, we get the buffer object from the global buffer row, and get the binary data
* from input row and deserialize it to buffer object, then we call
* `merge(buffer: T, input: T): Unit` to merge these 2 buffer objects.
*
*
* Stage 3: output results.
*
* If the aggregate mode is `Partial` or `PartialMerge`:
*     1. The framework calls `serializeAggregateBufferInPlace` to replace the buffer object in the
* global buffer row with binary data.
*     2. In `serializeAggregateBufferInPlace`, we get the buffer object from the global buffer row
* and call `serialize(buffer: T): Array[Byte]` to serialize the buffer object to binary.
*     3. The framework outputs buffer attributes and shuffle them to other nodes.
*
* If the aggregate mode is `Final` or `Complete`:
*     1. The framework calls `eval(buffer: InternalRow)` to calculate the final result.
*     2. In `eval`, we get the buffer object from the global buffer row and call
* `eval(buffer: T): Any` to get the final result.
*     3. The framework outputs these final results.
*
*
* Window function work flow:
* The framework calls `update(buffer: MutableRow, input: InternalRow)` several times and then
* call `eval(buffer: InternalRow)`, so there is no need for window operator to call
* `serializeAggregateBufferInPlace`.
*
*
* NOTE: SQL with TypedImperativeAggregate functions is planned in sort based aggregation,
* instead of hash based aggregation, as TypedImperativeAggregate use BinaryType as aggregation
* buffer's storage format, which is not supported by hash based aggregation. Hash based
* aggregation only support aggregation buffer of mutable types (like LongType, IntType that have
* fixed length and can be mutated in place in UnsafeRow).
* NOTE: The newly added ObjectHashAggregateExec supports TypedImperativeAggregate functions in
* hash based aggregation under some constraints.
*/
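The workflow in the comment above can be simulated without Spark. The following is a rough sketch under simplifying assumptions: `TypedAgg` is a hypothetical stand-in for the typed half of TypedImperativeAggregate (it is not the Spark trait), the "shuffle" is just a list of byte arrays, and the newline-separated wire format is a toy.

```scala
import java.nio.charset.StandardCharsets
import scala.collection.mutable.ArrayBuffer

object TypedAggSketch {
  /** Hypothetical, simplified stand-in for the typed methods of TypedImperativeAggregate[T]. */
  trait TypedAgg[T, IN, OUT] {
    def createAggregationBuffer(): T
    def update(buffer: T, input: IN): Unit
    def merge(buffer: T, other: T): Unit
    def serialize(buffer: T): Array[Byte]
    def deserialize(bytes: Array[Byte]): T
    def eval(buffer: T): OUT
  }

  /** Collects the strings of a group and concatenates them in eval. */
  object ConcatAgg extends TypedAgg[ArrayBuffer[String], String, String] {
    def createAggregationBuffer(): ArrayBuffer[String] = ArrayBuffer.empty
    def update(buffer: ArrayBuffer[String], input: String): Unit = { buffer += input }
    def merge(buffer: ArrayBuffer[String], other: ArrayBuffer[String]): Unit = { buffer ++= other }
    // Toy wire format: newline-separated UTF-8, so inputs must not contain '\n'.
    def serialize(buffer: ArrayBuffer[String]): Array[Byte] =
      buffer.mkString("\n").getBytes(StandardCharsets.UTF_8)
    def deserialize(bytes: Array[Byte]): ArrayBuffer[String] =
      if (bytes.isEmpty) ArrayBuffer.empty
      else ArrayBuffer(new String(bytes, StandardCharsets.UTF_8).split("\n", -1): _*)
    def eval(buffer: ArrayBuffer[String]): String = buffer.mkString
  }

  /** Runs Partial aggregation per partition, "shuffles" the bytes, then runs Final. */
  def run[T, IN, OUT](agg: TypedAgg[T, IN, OUT], partitions: Seq[Seq[IN]]): OUT = {
    // Partial mode: a fresh buffer per partition (stage 1), updated row by row
    // (stage 2), then serialized to binary for the shuffle (stage 3).
    val shuffled: Seq[Array[Byte]] = partitions.map { rows =>
      val buffer = agg.createAggregationBuffer()
      rows.foreach(row => agg.update(buffer, row))
      agg.serialize(buffer)
    }
    // Final mode: deserialize each incoming buffer and merge it (stage 2),
    // then compute the result with eval (stage 3).
    val finalBuffer = agg.createAggregationBuffer()
    shuffled.foreach(bytes => agg.merge(finalBuffer, agg.deserialize(bytes)))
    agg.eval(finalBuffer)
  }

  def main(args: Array[String]): Unit =
    println(run(ConcatAgg, Seq(Seq("<row>a</row>", "<row>b</row>"), Seq("<row>c</row>"))))
}
```

The point of the sketch is that all per-group state lives in the buffer object that the framework hands back to each method; nothing is kept in the function object itself.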

The framework maintains a global buffer, and this turned out to be a major pitfall.

2.1 Original code

/**
 * Concat a list of elements in a group.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr) - Concat a list of elements in a group.")
case class CollectGroupXMLPath(
    cols: Seq[Expression],
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect {

  def this(cols: Seq[Expression]) = this(cols, 0, 0)

  override val child = null
  override def children: Seq[Expression] = cols
  override def nullable: Boolean = true
  override def dataType: DataType = StringType
  override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType)
  override def aggBufferAttributes: Seq[AttributeReference] = super.aggBufferAttributes

  override def checkInputDataTypes(): TypeCheckResult = {
    val allOK = cols.forall(child =>
      !child.dataType.existsRecursively(_.isInstanceOf[MapType]))
    if (allOK) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure("group_xmlpath() cannot have map type data")
    }
  }

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def prettyName: String = "group_xmlpath"

  // All state below is local to this function object, not stored in the buffer row.
  override protected[this] val buffer: mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
  protected[this] val strbuffer: StringBuilder = new StringBuilder
  private var columnNames: Seq[(String, String)] = Seq.fill(cols.length)(("", ""))
  private var rootSpec = ("<root>", "</root>")
  private var rowSpec = ("<row>", "</row>")

  override def initialize(b: InternalRow): Unit = {
    buffer.clear()
    strbuffer.clear()
    initializeColNames
    strbuffer.append(rootSpec._1)
  }

  // The last argument may be an array literal: root tag, row tag, then column tags.
  private def initializeColNames = {
    cols.last match {
      case Literal(v, d) if d.isInstanceOf[ArrayType] =>
        val av = v.asInstanceOf[GenericArrayData]
        val names = av.array.map(_.toString.trim)
        val namepair = names.map(e => if (e.length > 0) (s"<$e>", s"</$e>") else ("", "")).toSeq
        rootSpec = namepair(0)
        rowSpec = namepair(1)
        columnNames = namepair.slice(2, namepair.length)
      case _ =>
    }
  }

  // Appends one <row>...</row> element per input row to the local string buffer.
  override def update(b: InternalRow, input: InternalRow): Unit = {
    strbuffer.append(rowSpec._1)
    for (i <- 0 to (cols.length - 2)) {
      strbuffer.append(columnNames(i)._1)
        .append(cols(i).eval(input))
        .append(columnNames(i)._2)
    }
    strbuffer.append(rowSpec._2)
  }

  override def merge(buffer: InternalRow, input: InternalRow): Unit = {
    sys.error("group_xmlpath cannot be used in partial aggregations.")
  }

  override def eval(input: InternalRow): Any = {
    strbuffer.append(rootSpec._2)
    UTF8String.fromString(strbuffer.toString())
  }
}

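The old code stopped working because both update and merge now initialize the buffer, that is, they call the initialize method. In the previous version the buffer was a local one whose setup was written entirely in initialize, so the later merge and update phases cleared that local buffer. I tried patching the source so that the execution flow skipped the merge phase, but then eval was never called and the result was wrong. After studying the interface comments carefully, I decided to work with the global buffer instead.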
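The difference between the two approaches can be illustrated with a small plain-Scala model. This is only a simplified sketch of the failure described above, not Spark code: `LocalStateAgg`, `BufferStateAgg`, and the two driver functions are hypothetical, and the assumption is that the framework sets up a fresh buffer (and therefore re-runs initialization) before a later stage.

```scala
import scala.collection.mutable.ArrayBuffer

object BufferStateSketch {
  /** Keeps rows in a field of the function object; initialize() wipes that field. */
  class LocalStateAgg {
    private val local = new StringBuilder
    def initialize(): Unit = local.clear()
    def update(input: String): Unit = { local.append(input) }
    def eval(): String = local.toString
  }

  /** Keeps rows in a buffer object owned by the caller (the "framework"). */
  class BufferStateAgg {
    def createAggregationBuffer(): ArrayBuffer[String] = ArrayBuffer.empty
    def update(buffer: ArrayBuffer[String], input: String): Unit = { buffer += input }
    def merge(buffer: ArrayBuffer[String], other: ArrayBuffer[String]): Unit = { buffer ++= other }
    def eval(buffer: ArrayBuffer[String]): String = buffer.mkString
  }

  def runLocal(rows: Seq[String]): String = {
    val agg = new LocalStateAgg
    agg.initialize()                 // partial stage sets up its buffer
    rows.foreach(r => agg.update(r))
    agg.initialize()                 // final stage sets up a new buffer: local state is wiped
    agg.eval()
  }

  def runBuffered(rows: Seq[String]): String = {
    val agg = new BufferStateAgg
    val partial = agg.createAggregationBuffer()
    rows.foreach(r => agg.update(partial, r))
    val finalBuffer = agg.createAggregationBuffer() // a new buffer does not touch `partial`
    agg.merge(finalBuffer, partial)
    agg.eval(finalBuffer)
  }

  def main(args: Array[String]): Unit = {
    println(s"local state:  '${runLocal(Seq("a", "b"))}'")
    println(s"buffer state: '${runBuffered(Seq("a", "b"))}'")
  }
}
```

In this model the local-state version loses its rows as soon as initialization runs a second time, while the buffer-state version keeps them because the data travels with the buffer object rather than with the function instance.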

2.2 Revised code

The code I ended up with is as follows:

case class CollectGroupXMLPath(
                                cols: Seq[Expression],
                                mutableAggBufferOffset: Int = 0,
                                inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {


  def this(cols: Seq[Expression]) = this(cols, 0, 0)

  override val child = cols.head

  override def children: Seq[Expression] = cols

  override def nullable: Boolean = true

  override def dataType: DataType = StringType

  //override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType)


  // override def aggBufferAttributes: Seq[AttributeReference] = super.aggBufferAttributes

  override def checkInputDataTypes(): TypeCheckResult = {
    val allOK = cols.forall(child =>
      !child.dataType.existsRecursively(_.isInstanceOf[MapType]))
    if (allOK) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure("group_xmlpath() cannot have map type data")
    }
  }

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def prettyName: String = "group_xmlpath"

  protected[this] val strbuffer: StringBuilder = new StringBuilder

  private var columnNames: Seq[(String, String)] = Seq.fill(cols.length)(("", ""))

  private var rootSpec = ("<root>", "</root>")
  private var rowSpec = ("<row>", "</row>")

  private[this] val anyObjectType = ObjectType(classOf[AnyRef])

  /* override def initialize(b: InternalRow): Unit = {
     buffer.clear()

     strbuffer.clear()

     initializeColNames

     strbuffer.append(rootSpec._1)
     createAggregationBuffer()
   }*/

  private def initializeColNames = {
    cols.last match {
      case Literal(v, d) if d.isInstanceOf[ArrayType] =>
        val av = v.asInstanceOf[GenericArrayData]
        val names = av.array.map(_.toString.trim)
        val namepair = names.map(e => if (e.length > 0) (s"<$e>", s"</$e>") else ("", "")).toSeq
        rootSpec = namepair(0)
        rowSpec = namepair(1)
        columnNames = namepair.slice(2, namepair.length)
      case _ =>
    }
  }

  override def update(b: InternalRow, input: InternalRow): Unit = {
    // Note: remember to clear local buffer first to avoid redundant data
    strbuffer.clear()
    strbuffer.append(rowSpec._1)
    for (i <- 0 to (cols.length - 2)) {
      strbuffer.append(columnNames(i)._1)
        .append(cols(i).eval(input))
        .append(columnNames(i)._2)
    }
    strbuffer.append(rowSpec._2)
    val out = InternalRow.fromSeq(Array(UTF8String.fromString(strbuffer.toString())))
    // force to merge input buffer into global buffer
    b(mutableAggBufferOffset) = getBufferObject(b) += out
  }


  private def getBufferObject(bufferRow: InternalRow): ArrayBuffer[Any] = {
    bufferRow.get(mutableAggBufferOffset, anyObjectType).asInstanceOf[ArrayBuffer[Any]]
  }

  override def merge(buffer: InternalRow, input: InternalRow): Unit = {
    super.merge(buffer, input)
  }

  override def eval(input: InternalRow): Any = {
    val head = input.toSeq(Seq(StringType)).head
    var buff = ArrayBuffer[UTF8String]()
    if (head.isInstanceOf[ArrayBuffer[UTF8String]])
      buff = head.asInstanceOf[ArrayBuffer[UTF8String]]
    val out = new mutable.StringBuilder()
    // reformat the output
    out.append(rootSpec._1)
    val tmp = new mutable.StringBuilder()
    for (i <- 0 until buff.length) {
      out.append(tmp).append(buff(i)).toString()
      tmp.clear()
    }
    out.append(rootSpec._2)
    UTF8String.fromString(out.toString())
  }

  private lazy val projection = UnsafeProjection.create(
    Array[DataType](ArrayType(elementType = child.dataType, containsNull = false)))

  override def serialize(obj: mutable.ArrayBuffer[Any]): Array[Byte] = {
    val array = new GenericArrayData(Array(UTF8String.fromString(strbuffer.toString())))
    val bytes = projection.apply(InternalRow.apply(array)).getBytes()
    bytes
  }

  override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
}

3. Function registration

Finally, the function has to be registered by adding one line to the FunctionRegistry object:

expression[CollectGroupXMLPath]("group_xmlpath"),

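To finish, a few off-topic remarks on debugging Spark SQL:

1. Inspect the physical plan. Prefix the SQL with EXPLAIN and analyze the printed physical plan.

2. Use println liberally to print intermediate results and see what gets called.

3. Debug by exception. Where the call chain is unclear, throw an exception and read the call stack.

4. Breakpoints. Because Spark is a parallel computing framework, pausing at a breakpoint for too long can cause lost heartbeats and similar problems, so this is sometimes impractical.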
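Tips 2 and 3 can be combined without actually failing the job: create a Throwable, but only print its stack trace. The helper below is a small sketch; `CallerTrace`, `callStack`, and `suspiciousMethod` are hypothetical names, and in a real investigation the call would be placed temporarily inside a method such as update or merge.

```scala
object CallerTrace {
  /** Returns the current call stack as "Class.method(File:line)" strings, innermost first. */
  def callStack(maxFrames: Int = 20): Seq[String] =
    new Throwable().getStackTrace.toSeq.take(maxFrames).map(_.toString)

  /** Stand-in for a method whose callers we want to discover. */
  def suspiciousMethod(): Seq[String] = {
    val stack = callStack()
    // Print the frames instead of throwing, so execution continues normally.
    stack.foreach(frame => println(s"  at $frame"))
    stack
  }

  def main(args: Array[String]): Unit = {
    suspiciousMethod()
  }
}
```

On a cluster the output ends up in the executor logs rather than on the driver console, so it is usually easiest to try this in local mode first.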