除了处理任何类型的值,Spark还允许我们创建以下分组类型:
- 最简单的分组是通过在select语句中执行聚合来汇总完整的DataFrame。
- “group by”允许指定一个或多个key以及一个或多个聚合函数来转换值列。
- “window”使您能够指定一个或多个key以及一个或多个聚合函数来转换值列。然而,函数的行输入与当前行有某种关联。
- 一个“grouping set”,您可以使用它在多个不同级别聚合。grouping set可以在SQL中作为原语使用,也可以通过DataFrame中的rollups 和 cubes使用。
- “rollup”使您能够指定一个或多个键以及一个或多个聚合函数来转换值列,这些值列将按层次结构汇总。
- “cube”允许您指定一个或多个键以及一个或多个聚合函数来转换值列,这些值列将在所有列的组合中进行汇总。
每个分组返回一个RelationalGroupedDataset,我们在这个数据集中指定聚合。
提示需要考虑的一个重要问题是,你需要一个多么精确的答案。在对大数据进行计算时,获得一个问题的准确答案可能相当昂贵,而仅仅要求一个合理精确度的近似值往往要便宜得多。您将注意到,我们在整本书中都提到了一些近似函数,这通常是提高Spark作业的速度和执行的好机会,特别是对于交互式和在线分析。 |
让我们从读取关于购买的数据开始,重新分区数据,使其拥有更少的分区(因为我们知道它是存储在许多小文件中的小数据量),并缓存结果以便快速访问:
// in Scala
val df = spark.read.format("csv")
.option("header", "true")
.option("inferSchema", "true")
.load("/data/retail-data/all/*.csv")
.coalesce(5)
df.cache()
df.createOrReplaceTempView("dfTable")
# in Python
df = spark.read.format("csv")\
.option("header", "true")\
.option("inferSchema", "true")\
.load("/data/retail-data/all/*.csv")\
.coalesce(5)
df.cache()
df.createOrReplaceTempView("dfTable")
这里有一个数据样本,你可以参考一些函数的输出:
如前所述,基本聚合应用于整个DataFrame。最简单的例子是count方法:
df.count() == 541909
如果你一章一章地读这本书,你就会知道count实际上是一个Action,而不是一个transformation,所以它会立即返回。您可以使用count了解数据集的总体大小,但是另一种常见的模式是使用它在内存中缓存整个DataFrame,就像我们在本例中所做的那样。现在,这个方法有点离群,因为它作为一个方法(在本例中)而不是函数存在,并且被热切地求值,而不是一个延迟转换。在下一节中,我们还将看到count被用作一个延迟函数。
7.1. 聚合函数
除了可以在DataFrames或.stat中出现的特殊情况之外,所有聚合都可以作为函数使用,如我们在第6章中看到的那样。您可以在org.apache.spark.sql.functions包中找到大多数聚合函数。
提示可用的SQL函数与我们可以用Scala和Python导入的函数之间存在一些差距。这将随着每个spark版本会发生改变,所以不可能包含一个确定的列表。本节介绍最常见的函数。
|
7.1.1. Count
第一个值得讨论的函数是count,但在本例中它将作为transformation而不是action。在这种情况下,我们可以做以下两件事之一:指定要计数的特定列,或者使用count(*)或count(1)来表示我们希望将每一行计数为字面常量列,如下例所示:
// in Scala
import org.apache.spark.sql.functions.count
df.select(count("StockCode")).show() // 541909
# in Python
from pyspark.sql.functions import count
df.select(count("StockCode")).show() # 541909
-- in SQL
SELECT COUNT(*) FROM dfTable
注意
当涉及到null值和计数时,有许多问题。例如,当执行count(*)时,Spark将计数null值(包括包含所有null的行)。但是,在计算单个列时,Spark不会计算空值。
7.1.2. countDistinct
有时,总数并不相关;而是你想要的唯一组的个数。要得到这个数字,可以使用countDistinct函数。这对于单独的列来说更相关:
// in Scala
import org.apache.spark.sql.functions.countDistinct
df.select(countDistinct("StockCode")).show() // 4070
# in Python
from pyspark.sql.functions import countDistinct
df.select(countDistinct("StockCode")).show() # 4070
-- in SQL
SELECT COUNT(DISTINCT *) FROM DFTABLE
7.1.3. approx_count_distinct
通常,我们发现自己处理的是大型数据集,而精确的不同计数是不相关的。在某些情况下,近似到一定程度的精度会工作得很好,为此,您可以使用approx_count_distinct函数:
// in Scala
import org.apache.spark.sql.functions.approx_count_distinct
df.select(approx_count_distinct("StockCode", 0.1)).show() // 3364
# in Python
from pyspark.sql.functions import approx_count_distinct
df.select(approx_count_distinct("StockCode", 0.1)).show() # 3364
-- in SQL
SELECT approx_count_distinct(StockCode, 0.1) FROM DFTABLE
您将注意到,approx_count_distinct使用了另一个参数,您可以使用该参数指定允许的最大估计错误。在本例中,我们指定了一个相当大的错误,因此接收到的答案非常遥远,但是完成得比countDistinct快。使用更大的数据集,您将看到更大的性能提升。
7.1.4. first 和 last
通过使用这两个明显命名的函数,可以从DataFrame获得第一个和最后一个值。这将基于DataFrame中的行,而不是基于DataFrame中的值:
// in Scala
import org.apache.spark.sql.functions.{first, last}
df.select(first("StockCode"), last("StockCode")).show()
# in Python
from pyspark.sql.functions import first, last
df.select(first("StockCode"), last("StockCode")).show()
-- in SQL
SELECT first(StockCode), last(StockCode) FROM dfTable
7.1.5. min和max
要从DataFrame中提取最小值和最大值,可以使用min和max函数:
// in Scala
import org.apache.spark.sql.functions.{min, max}
df.select(min("Quantity"), max("Quantity")).show()
# in Python
from pyspark.sql.functions import min, max
df.select(min("Quantity"), max("Quantity")).show()
-- in SQL
SELECT min(Quantity), max(Quantity) FROM dfTable
7.1.6. Sum
另一个简单的任务是使用sum函数相加一行中的所有值:
// in Scala
import org.apache.spark.sql.functions.sum
df.select(sum("Quantity")).show() // 5176450
# in Python
from pyspark.sql.functions import sum
df.select(sum("Quantity")).show() # 5176450
-- in SQL
SELECT sum(Quantity) FROM dfTable
1.1.1. sumDistinct
除了对总数求和,还可以使用sumDistinct函数对一组不同的值求和:
// in Scala
import org.apache.spark.sql.functions.sumDistinct
df.select(sumDistinct("Quantity")).show() // 29310
# in Python
from pyspark.sql.functions import sumDistinct
df.select(sumDistinct("Quantity")).show() # 29310
-- in SQL
SELECT SUM(Quantity) FROM dfTable -- 29310
7.1.8. avg
虽然可以通过求和除以计数来计算平均值,但是Spark提供了一种更简单的方法来通过avg或平均值函数获得平均值。在本例中,我们使用alias,以便以后更容易地重用这些列:
// in Scala
import org.apache.spark.sql.functions.{sum, count, avg, expr}
df.select(
count("Quantity").alias("total_transactions"),
sum("Quantity").alias("total_purchases"),
avg("Quantity").alias("avg_purchases"),
expr("mean(Quantity)").alias("mean_purchases"))
.selectExpr(
"total_purchases/total_transactions",
"avg_purchases",
"mean_purchases").show()
# in Python
from pyspark.sql.functions import sum, count, avg, expr
df.select(
count("Quantity").alias("total_transactions"),
sum("Quantity").alias("total_purchases"),
avg("Quantity").alias("avg_purchases"),
expr("mean(Quantity)").alias("mean_purchases"))\
.selectExpr(
"total_purchases/total_transactions",
"avg_purchases",
"mean_purchases").show()
提示
还可以通过指定distinct来平均所有不同的值。事实上,大多数聚合函数只支持在不同的值上执行此操作。
7.1.9. 方差和标准差
计算均值很自然地会带来方差和标准差的问题。这两个指标都是衡量数据在均值附近的分布情况。
方差:
标准差是方差的平方根。您可以在Spark中使用它们各自的函数计算这些值。但是,需要注意的是Spark既有样本标准差的公式,也有总体标准差的公式。这些是根本不同的统计公式,我们需要区别对待。默认情况下,如果使用variance o或stddev函数,Spark执行样本标准差或方差的公式。您还可以显式地指定这些或参考总体标准差或方差:
// in Scala
import org.apache.spark.sql.functions.{var_pop, stddev_pop}
import org.apache.spark.sql.functions.{var_samp, stddev_samp}
df.select(var_pop("Quantity"), var_samp("Quantity"),
stddev_pop("Quantity"), stddev_samp("Quantity")).show()
# in Python
from pyspark.sql.functions import var_pop, stddev_pop
from pyspark.sql.functions import var_samp, stddev_samp
df.select(var_pop("Quantity"), var_samp("Quantity"),
stddev_pop("Quantity"), stddev_samp("Quantity")).show()
-- in SQL
SELECT var_pop(Quantity), var_samp(Quantity),
stddev_pop(Quantity), stddev_samp(Quantity)
FROM dfTable
7.1.10. 偏态与峰度skewnessand kurtosis
偏度和峰度都是数据中极值点的度量。偏度衡量的是数据在均值附近的不对称性,而峰度衡量的是数据尾部的不对称性。在将数据建模为随机变量的概率分布时,这两者都是相关的。虽然这里我们不会详细讨论这些定义背后的数学原理,但是您可以在internet上很容易地查找定义。你可以用以下函数来计算:
import org.apache.spark.sql.functions.{skewness, kurtosis}
df.select(skewness("Quantity"), kurtosis("Quantity")).show()
# in Python
from pyspark.sql.functions import skewness, kurtosis
df.select(skewness("Quantity"), kurtosis("Quantity")).show()
-- in SQL
SELECT skewness(Quantity), kurtosis(Quantity) FROM dfTable
7.1.11. 协方差和相关系数
我们讨论了单列聚合,但是一些函数比较了两个不同列中值的交互。其中两个函数是cov和corr,分别表示协方差和相关性。相关性度量皮尔逊相关系数,该系数在-1和+1之间缩放。协方差是根据数据中的输入进行缩放的。同var函数一样,协方差可以计算为样本协方差,也可以计算为总体协方差。因此,指定要使用哪个公式非常重要。相关性没有这个概念,因此没有计算总体或样本。它们是这样工作的:
// in Scala
import org.apache.spark.sql.functions.{corr, covar_pop, covar_samp}
df.select(corr("InvoiceNo", "Quantity"), covar_samp("InvoiceNo", "Quantity"),
covar_pop("InvoiceNo", "Quantity")).show()
# in Python
from pyspark.sql.functions import corr, covar_pop, covar_samp
df.select(corr("InvoiceNo", "Quantity"), covar_samp("InvoiceNo", "Quantity"),
covar_pop("InvoiceNo", "Quantity")).show()
-- in SQL
SELECT corr(InvoiceNo, Quantity), covar_samp(InvoiceNo, Quantity),
covar_pop(InvoiceNo, Quantity)
FROM dfTable
7.1.12. 聚合为复杂类型
在Spark中,您不仅可以使用公式执行数值的聚合,还可以对复杂类型执行它们。例如,我们可以收集给定列中存在的值列表,或者只收集到一个集合中的唯一值。您可以使用它来执行更多的编程访问稍后在管道或通过一个用户定义函数(UDF)的整个集合:
// in Scala
import org.apache.spark.sql.functions.{collect_set, collect_list}
df.agg(collect_set("Country"), collect_list("Country")).show()
# in Python
from pyspark.sql.functions import collect_set, collect_list
df.agg(collect_set("Country"), collect_list("Country")).show()
-- in SQL
SELECT collect_set(Country), collect_set(Country) FROM dfTable
7.2. 分组
到目前为止,我们只执行了DataFrame级别的聚合。更常见的任务是基于数据中的(groups)组执行计算。这通常是在分类数据上完成的,我们将数据分组在一列上,并对该组中的其他列执行一些计算。最好的解释方法是开始执行一些分组。第一个是计数,和之前一样。我们将按每个唯一的发票编号进行分组,并得到发票上的项目数量。注意,这将返回另一个DataFrame,并延迟执行。我们分两个阶段进行分组。首先指定要分组的列,然后指定聚合。第一步返回RelationalGroupedDataset,第二步返回DataFrame。如前所述,我们可以指定任意数量的列来分组:
df.groupBy("InvoiceNo", "CustomerId").count().show()
-- in SQL
SELECT count(*) FROM dfTable GROUP BY InvoiceNo, CustomerId
表达式分组
正如我们前面看到的,计数是一种特殊的情况,因为它作为一种方法存在。为此,通常我们更喜欢使用count函数。我们没有将该函数作为表达式传递到select语句中,而是在agg中指定它。这使得您可以传入任何只需要指定一些聚合的表达式。您甚至可以在转换列后使用别名列,以便以后在DataFrame中使用:
// in Scala
import org.apache.spark.sql.functions.count
df.groupBy("InvoiceNo").agg(
count("Quantity").alias("quan"),
expr("count(Quantity)")).show()
# in Python
from pyspark.sql.functions import count
df.groupBy("InvoiceNo").agg(
count("Quantity").alias("quan"),
expr("count(Quantity)")).show()
Map分组
有时,可以更容易地将转换指定为一系列map,其中键为列,值为要执行的聚合函数(作为字符串)。如果以内联方式指定多个列名,也可以重用它们:
// in Scala
df.groupBy("InvoiceNo").agg("Quantity"->"avg", "Quantity"->"stddev_pop").show()
# in Python
df.groupBy("InvoiceNo").agg(expr("avg(Quantity)"),expr("stddev_pop(Quantity)"))\
.show()
-- in SQL
SELECT avg(Quantity), stddev_pop(Quantity), InvoiceNo FROM dfTable
GROUP BY InvoiceNo
7.3. 窗口函数
您还可以使用窗口函数来执行一些独特的聚合,方法是在特定的数据“窗口”上计算一些聚合,通过使用对当前数据的引用来定义这些数据。此窗口规范确定将向此函数传递哪些行。这有点抽象,可能类似于标准的group by,让我们进一步区分它们。group by接受数据,并且每一行只能进入一个组。窗口函数根据一组行(称为frame)为表的每个输入行计算返回值。每一行可以属于一个或多个frame。一个常见的用例是查看某个值的滚动平均值,其中每一行代表一天。如果你这样做,每一行将会在7个不同的坐标系中结束。稍后我们将介绍如何定义frame,但是为了便于您参考,Spark支持三种窗口函数:排序函数、分析函数和聚合函数。图7-1说明了如何将给定的行分成多个frame。
为了演示,我们将添加一个日期列,该列将把我们的发票日期转换为一个只包含日期信息(而不包含时间信息)的列:
// in Scala
import org.apache.spark.sql.functions.{col, to_date}
val dfWithDate = df.withColumn("date", to_date(col("InvoiceDate"),
"MM/d/yyyy H:mm"))
dfWithDate.createOrReplaceTempView("dfWithDate")
# in Python
from pyspark.sql.functions import col, to_date
dfWithDate = df.withColumn("date", to_date(col("InvoiceDate"), "MM/d/yyyy H:mm"))
dfWithDate.createOrReplaceTempView("dfWithDate")
窗口函数的第一步是创建一个窗口规范。注意,partition by与我们到目前为止介绍的分区方案概念无关。这只是一个类似的概念,描述我们将如何分割我们的小组。排序决定给定分区内的排序,最后,Frame规范(rowsBetween语句)根据对当前输入行的引用声明哪些行将包含在框架中。
// in Scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.col
val windowSpec = Window
.partitionBy("CustomerId", "date")
.orderBy(col("Quantity").desc)
.rowsBetween(Window.unboundedPreceding, Window.currentRow)
# in Python
from pyspark.sql.window import Window
from pyspark.sql.functions import desc
windowSpec = Window\
.partitionBy("CustomerId", "date")\
.orderBy(desc("Quantity"))\
.rowsBetween(Window.unboundedPreceding, Window.currentRow)
现在我们要使用聚合函数来了解每个特定客户的更多信息。一个例子可能是建立所有时间的最大购买量。为了回答这个问题,我们使用与前面通过传递列名或表达式所看到的相同的聚合函数。此外,我们指定了窗口规范,定义了该函数将应用于哪些数据frame:
import org.apache.spark.sql.functions.max
val maxPurchaseQuantity = max(col("Quantity")).over(windowSpec)
# in Python
from pyspark.sql.functions import max
maxPurchaseQuantity = max(col("Quantity")).over(windowSpec)
您将看到,这将返回一个列(或表达式)。现在我们可以在DataFrame select语句中使用它。在此之前,我们将创建购买数量排名。为此,我们使用dense_rank函数来确定哪个日期具有每个客户的最大购买量。我们使用dense_rank而不是rank来避免在有绑定值(或者在我们的例子中是重复的行)时,在排序序列中出现间隔:
// in Scala
import org.apache.spark.sql.functions.{dense_rank, rank}
val purchaseDenseRank = dense_rank().over(windowSpec)
val purchaseRank = rank().over(windowSpec)
# in Python
from pyspark.sql.functions import dense_rank, rank
purchaseDenseRank = dense_rank().over(windowSpec)
purchaseRank = rank().over(windowSpec)
它还返回一个列,我们可以在select语句中使用该列。现在我们可以执行一个select来查看计算出的窗口值:
// in Scala
import org.apache.spark.sql.functions.col
dfWithDate.where("CustomerId IS NOT NULL").orderBy("CustomerId")
.select(
col("CustomerId"),
col("date"),
col("Quantity"),
purchaseRank.alias("quantityRank"),
purchaseDenseRank.alias("quantityDenseRank"),
maxPurchaseQuantity.alias("maxPurchaseQuantity")).show()
# in Python
from pyspark.sql.functions import col
dfWithDate.where("CustomerId IS NOT NULL").orderBy("CustomerId")\
.select(
col("CustomerId"),
col("date"),
col("Quantity"),
purchaseRank.alias("quantityRank"),
purchaseDenseRank.alias("quantityDenseRank"),
maxPurchaseQuantity.alias("maxPurchaseQuantity")).show()
-- in SQL
SELECT CustomerId, date, Quantity,
rank(Quantity) OVER (PARTITION BY CustomerId, date
ORDER BY Quantity DESC NULLS LAST
ROWS BETWEEN
UNBOUNDED PRECEDING AND
CURRENT ROW) as rank,
dense_rank(Quantity) OVER (PARTITION BY CustomerId, date
ORDER BY Quantity DESC NULLS LAST
ROWS BETWEEN
UNBOUNDED PRECEDING AND
CURRENT ROW) as dRank,
max(Quantity) OVER (PARTITION BY CustomerId, date
ORDER BY Quantity DESC NULLS LAST
ROWS BETWEEN
UNBOUNDED PRECEDING AND
CURRENT ROW) as maxPurchase
FROM dfWithDate WHERE CustomerId IS NOT NULL ORDER BY CustomerId
7.4. GroupingSets(未完待续......)
到目前为止,在本章中,我们已经看到了简单的group-by表达式,我们可以使用这些表达式在一组列上聚合这些列中的值。然而,有时我们需要更完整的东西—跨多个组的聚合。我们通过使用分组集来实现这一点。分组集是用于将集合集组合在一起的低级工具。它们使您能够在它们的group-by语句中创建任意聚合。让我们通过一个例子来更好地理解。在这里,我们想要得到所有股票代码和客户的总数量。为此,我们将使用以下SQL表达式:
// in Scala
val dfNoNull = dfWithDate.drop()
dfNoNull.createOrReplaceTempView("dfNoNull")
# in Python
dfNoNull = dfWithDate.drop()
dfNoNull.createOrReplaceTempView("dfNoNull")
-- in SQL
SELECT CustomerId, stockCode, sum(Quantity) FROM dfNoNull
GROUP BY customerId, stockCode
ORDER BY CustomerId DESC, stockCode DESC
你可以用一个分组集(GROUPING SETS)做同样的事情:
-- in SQL
SELECT CustomerId, stockCode, sum(Quantity) FROM dfNoNull
GROUP BY customerId, stockCode GROUPING SETS((customerId, stockCode))
ORDER BY CustomerId DESC, stockCode DESC
警告
分组集依赖于聚合级别的空值。如果不过滤空值,将得到不正确的结果。这适用于多维数据集、滚动和分组集。
很简单,但是如果您还想包含项目的总数,而不管客户代码或库存代码是什么呢?使用传统的分组声明,这是不可能的。但是,使用分组集很简单:我们只需指定我们也希望在分组集中聚合到那个级别。这实际上是几个不同分组的联合:
-- in SQL
SELECT CustomerId, stockCode, sum(Quantity) FROM dfNoNull
GROUP BY customerId, stockCode GROUPING SETS((customerId, stockCode),())
ORDER BY CustomerId DESC, stockCode DESC
分组集操作符仅在SQL中可用。要在DataFrames中执行相同的操作,您可以使用rollup和cube操作符—它们允许我们获得相同的结果。我们来看一下。
Rollups
到目前为止,我们一直在研究显式分组。当我们设置多个列的分组键时,Spark会查看这些键以及数据集中可见的实际组合。rollup是一个多维聚合,它为我们执行各种分组样式的计算。让我们创建一个汇总,跨越时间(与我们的新日期列)和空间(国家列),创建一个新的DataFrame包括总数除以所有日期,每个日期的总和DataFrame,和每个国家的小计DataFrame日期:
val rolledUpDF = dfNoNull.rollup("Date", "Country").agg(sum("Quantity"))
.selectExpr("Date", "Country", "`sum(Quantity)` as total_quantity")
.orderBy("Date")
rolledUpDF.show()
# in Python
rolledUpDF = dfNoNull.rollup("Date", "Country").agg(sum("Quantity"))\
.selectExpr("Date", "Country", "`sum(Quantity)` as total_quantity")\
.orderBy("Date")
rolledUpDF.show()
现在你看到空值的地方就是你找到总数的地方。两个rollup列中的null指定这两个列的总金额:
Cube
多维数据集将rollup提升到更深的级别。多维数据集不是分层处理元素,而是跨所有维度执行相同的操作。这意味着它不仅会按时间顺序排列,还会按国家顺序排列。为了再次提出这个问题,您能制作一个包含以下内容的表吗?
- The total across all dates and countries
- The total for each date across all countries
- The total for each country on each date
- The total for each country across all dates
方法调用非常类似,但是我们不调用rollup,而是调用cube:
// in Scala
dfNoNull.cube("Date", "Country").agg(sum(col("Quantity")))
.select("Date", "Country", "sum(Quantity)").orderBy("Date").show()
# in Python
from pyspark.sql.functions import sum
dfNoNull.cube("Date", "Country").agg(sum(col("Quantity")))\
.select("Date", "Country", "sum(Quantity)").orderBy("Date").show()
这是对表中几乎所有信息的一个快速且容易访问的摘要,这是创建一个其他人稍后可以使用的快速摘要表的好方法。
Grouping Metadata
有时,在使用多维数据集和滚动时,您希望能够查询聚合级别,以便能够轻松地相应地过滤它们。我们可以使用grouping_id来实现这一点,它会给我们一个列,指定我们的结果集中的聚合级别。表7–1 group id的用途
Grouping id | 描述 |
3 | 这将出现在最高级别的聚合中,无论customerId和stockCode是什么,它都会给出总数量。 |
2 | 这将出现在所有单个股票代码的聚合中。这给出了每个股票代码的总数量,而不考虑客户。 |
1 | 这将给出我们每个客户的总数量,无论购买的项目。 |
0 | 这将给出单个customerId和stockCode组合的总数量。 |
这有点抽象,所以很有必要自己尝试理解这种行为:
// in Scala
import org.apache.spark.sql.functions.{grouping_id, sum, expr}
dfNoNull.cube("customerId", "stockCode").agg(grouping_id(), sum("Quantity"))
.orderBy(expr("grouping_id()").desc)
.show()
Pivot
数据透视pivot使您能够将行转换为列。例如,在当前数据中有一个Country列。使用pivot,我们可以根据这些给定国家的某些功能进行聚合,并以一种易于查询的方式显示它们:
// in Scala
val pivoted = dfWithDate.groupBy("date").pivot("Country").sum()
# in Python
pivoted = dfWithDate.groupBy("date").pivot("Country").sum()
现在,这个DataFrame将为每个国家、数值变量和指定日期的列的组合提供一个列。例如,对于美国,我们有以下列:USA_sum(Quantity)、USA_sum(UnitPrice)、USA_sum(CustomerID)。这表示数据集中每个数字列对应一个列(因为我们刚刚对所有列执行了聚合)。下面是一个例子查询和结果从这个数据:
现在,所有列都可以通过单个分组计算,但是pivot的值取决于您希望如何研究数据。如果某个列的基数足够低,可以将其转换为列,以便用户可以看到模式并立即知道要查询什么,那么它可能很有用。
7.5. 用户自定义聚合函数
用户定义聚合函数(UDAFs)是用户基于自定义公式或业务规则定义自己的聚合函数的一种方法。您可以使用UDAFs计算输入数据组上的自定义计算(而不是单行)。Spark维护一个单独的AggregationBuffer来存储每组输入数据的中间结果。要创建UDAF,您必须继承UserDefinedAggregateFunction基类,并实现以下方法:
- inputSchema将输入参数表示为StructType
- bufferSchema将中间UDAF结果表示为StructType
- dataType表示返回dataType
- deterministic是一个布尔值,它指定这个UDAF是否会为给定的输入返回相同的结果
- initialize允许初始化聚合缓冲区的值
- update描述如何基于给定行更新内部缓冲区
- merge描述如何合并两个聚合缓冲区
- evaluate将生成聚合的最终结果
下面的示例实现了一个BoolAnd,它将告诉我们(对于给定列)是否所有行都为真;如果不是,则返回false:
// in Scala
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
class BoolAnd extends UserDefinedAggregateFunction {
def inputSchema: org.apache.spark.sql.types.StructType =
StructType(StructField("value", BooleanType) :: Nil)
def bufferSchema: StructType = StructType(
StructField("result", BooleanType) :: Nil
)
def dataType: DataType = BooleanType
def deterministic: Boolean = true
def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = true
}
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer(0) = buffer.getAs[Boolean](0) && input.getAs[Boolean](0)
}
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getAs[Boolean](0) && buffer2.getAs[Boolean](0)
}
def evaluate(buffer: Row): Any = {
buffer(0)
}
}
现在,我们只需实例化我们的类并/或将它注册为一个函数:
// in Scala
val ba = new BoolAnd spark.udf.register("booland", ba)
import org.apache.spark.sql.functions._
spark.range(1)
.selectExpr("explode(array(TRUE, TRUE, TRUE)) as t")
.selectExpr("explode(array(TRUE, FALSE, TRUE)) as f", "t")
.select(ba(col("t")), expr("booland(f)"))
.show()
UDAFs目前只能在Scala或Java中使用。然而,在Spark 2.3中,您还可以通过注册函数来调用Scala或Java UDF和UDAFs,正如我们在第6章的UDF部分中所示。更多信息,请访问SPARK-19439。
7.6. 结束语
本章介绍了可以在Spark中执行的不同类型和类型的聚合。您学习了简单的组到窗口函数以及滚动和多维数据集。第8章将讨论如何执行连接来将不同的数据源组合在一起。