背景

本文基于spark 3.2.0
由于codegen涉及到的知识点比较多,我们先来说清楚code"""""",我们暂且叫做code代码块

scala 字符串插值

要想搞清楚spark的code代码块,就得现搞清楚scala 字符串插值。
scala 字符串插值是2.10.0版本引用进来的新语法规则,可以直接允许使用者将变量引用直接插入到字符串中,如下:

val name = 'LI'
println(s"My name is $name")
输出:
My name is LI

这种资料很多,大家自行查阅资料理解。

code代码块

因为这块代码比较复杂,直接拿出例子来运行:
直接找到spark CastSuite.scala 第215行如下:

test("cast string to boolean II") {
    checkEvaluation(cast("abc", BooleanType), null)

之后在javaCode.scala 输出对应的想要debug的值,如下:

*/
    def code(args: Any*): Block = {
      sc.checkLengths(args)
      if (sc.parts.length == 0) {
        EmptyBlock
      } else {
        args.foreach {
          case _: ExprValue | _: Inline | _: Block =>
          case _: Boolean | _: Byte | _: Int | _: Long | _: Float | _: Double | _: String =>
          case other => throw QueryExecutionErrors.cannotInterpolateClassIntoCodeBlockError(other)
        }

        val (codeParts, blockInputs) = foldLiteralArgs(sc.parts, args)
        // scalasytle:off
        println(s"code: $codeParts")
        println(s"blockInputs: $blockInputs")
        // scalasytle:on

        CodeBlock(codeParts, blockInputs)
      }
    }

这样,运行后我们会发现,如下结果:

code: ArrayBuffer(
          if (org.apache.spark.sql.catalyst.util.StringUtils.isTrueString(, )) {
            ,  = true;
          } else if (org.apache.spark.sql.catalyst.util.StringUtils.isFalseString(, )) {
            ,  = false;
          } else {
            isNull_0 = true;
          }
        )
blockInputs: ArrayBuffer(((UTF8String) references[0] /* literal */), value_0, ((UTF8String) references[0] /* literal */), value_0)
result: if (org.apache.spark.sql.catalyst.util.StringUtils.isTrueString(((UTF8String) references[0] /* literal */))) {
            value_0 = true;
          } else if (org.apache.spark.sql.catalyst.util.StringUtils.isFalseString(((UTF8String) references[0] /* literal */))) {
            value_0 = false;
          } else {
            isNull_0 = true;
          }
...

而这段代码刚好和Cast.scala中的 castToBooleanCode方法是一一对应的的:

private[this] def castToBooleanCode(from: DataType): CastFunction = from match {
    case StringType =>
      val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}"
      (c, evPrim, evNull) =>
        val castFailureCode = if (ansiEnabled) {
          s"throw QueryExecutionErrors.invalidInputSyntaxForBooleanError($c);"
        } else {
          s"$evNull = true;"
        }
        val result = code"""
          if ($stringUtils.isTrueString($c)) {
            $evPrim = true;
          } else if ($stringUtils.isFalseString($c)) {
            $evPrim = false;
          } else {
            $castFailureCode
          }
        """
        // scalastyle:off
        println(s"result: $result")
        // scalastyle:on
      result

也就是说spark自定义的ExprValue类型的值被替换了(其实是Inline/Block/ExprValue这三种类型的值都会被替换,只不过这里没有体现),如下:

x

x

evPrim

被替换成了((UTF8String) references[0] /* literal */)

c

被替换成了value_0

而输出的result结果就是拼接完后的完整字符串。
我们这里是为了debug,才会把结果和对应的片段打印出来,
而在spark真正处理的时候,返回的是ExprCode类型的值,在真正需要代码生成的时候,才会调用的toString的方法生成对应的字符串

code代码块之间的连接

但是我们在Cast.scala的方法中我们看到的doGenCode是先调用child.genCode的方法的:

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val eval = child.genCode(ctx)
    val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)

    ev.copy(code = eval.code +
      castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast))
  }

那子节点的ExprCode怎么和父节点的ExprCode连接起来的呢?
其实这个和写代码的思路是一样的,每个子节点返回的ExprCode类型的值,都会对应为该方法体的的实现代码,返回值(包括了类型),spark额外增加了一个是否为null,如下:

case class ExprCode(var code: Block, var isNull: ExprValue, var value: ExprValue)

其中code是对应的方法体的实现代码,
isNull 是对应的是否为null,
value 代表的返回值

至于为什么会额外增加一个是否为null,还是和写代码的逻辑是一样的,因为只有不为空的情况下,代码才会正常的往下运行:

protected[this] def castCode(ctx: CodegenContext, input: ExprValue, inputIsNull: ExprValue,
    result: ExprValue, resultIsNull: ExprValue, resultType: DataType, cast: CastFunction): Block = {
    val javaType = JavaCode.javaType(resultType)
    code"""
      boolean $resultIsNull = $inputIsNull;
      $javaType $result = ${CodeGenerator.defaultValue(resultType)};
      if (!$inputIsNull) {
        ${cast(input, result, resultIsNull)}
      }
    """
  }

这里的!$inputIsNull判断,只有不为空了才进行下一步的转换操作,要不然会抛出异常。

这样把子节点的结果作为父节点的入参传入给对应的方法,这样生成的代码完全符合编码的逻辑,这样这部分也就说完了,当然这部分也是代码生成的重中之重,理解了这部分,代码生成这块就差不多了,其他的就是各个部分的实现,用心去看即可。