前言
不得不说,udf函数在spark开发中是非常方便的。有了这个提供,我们不仅可以操作spark dataframe。还可以直接操作数仓(hive)而无需再去过多精力研究hive的复杂函数。
值得高兴的是pyspark同样也支持udf的编写,我们知道初期的spark对于python并不十分友好,随着版本的更新也给python提供了更多的接口。
udf函数的编写
这个其实就是把python的函数绑定spark的udf函数中。我们先来初始化我们的数据源。
df = sc.parallelize([[30, '小明'],
[40, '小明'],
[80, '小明'],
[80, '小强'],
[30, '小强'],
[40, '小强'],
[80, '小强'],
[80, '小强']
]).toDF(['name', 'score'])
df.show()
+-----+----+
|score|name|
+-----+----+
| 30|小明|
| 40|小明|
| 80|小明|
| 80|小强|
| 30|小强|
| 40|小强|
| 80|小强|
| 80|小强|
+-----+----+
我们假设上述数据是两位同学本学年的成绩。现在我们想查看他们的成绩是否及格,我们可以通过map函数去操作。但是此时如果数据在数仓内,或者我们习惯spark sql 的用法。我们不想再进行过多rdd的转换。我们则可以通过spark udf函数来操作。
函数编写
from pyspark.sql.types import IntegerType, FloatType, StringType
from pyspark.sql.functions import udf, collect_list
def is_pass(line):
return "及格" if line >= 60 else "不及格"
# 进行udf函数绑定, 第一个参数是函数名, 第二个函数是返回值类型。
get_score_pass = udf(is_pass, StringType())
# 注册该函数到spark session中
spark.udf.register(name="get_score_pass", f=get_score_pass)
# 进行查询
df.registerTempTable('test_score')
spark.sql(
"select score, name, get_score_pass(score) as `是否及格` from test_score"
).show()
+-----+----+--------+
|score|name|是否及格|
+-----+----+--------+
| 30|小明| 不及格|
| 40|小明| 不及格|
| 80|小明| 及格|
| 80|小强| 及格|
| 30|小强| 不及格|
| 40|小强| 不及格|
| 80|小强| 及格|
| 80|小强| 及格|
+-----+----+--------+
注意事项
我们在绑定udf函数后需要进行注册。
spark.udf.register(name="get_score_pass", f=get_score_pass)
否则则会报以下错误
"Undefined function: 'get_score_pass'. This function is neither a registered temporary function nor a permanent function
聚合函数的编写
上述是普通udf函数的编写,但在我们生产中很多时候需要使用聚合函数。虽然spark也提供了对标主流数据库大多通用类聚合函数,但是还是有一些特殊场景需要我们自定义聚合函数。
这个时候udf就是一个不错的选择。
比如上述的数据中我们要求中每一位同学的及格率,这个时候普通的聚合函数一次就很难做到了,我们可以编写udf函数。
聚合函数前使用
我们先需要把聚合到的字段放到一个列表里,比如上述的成绩
spark.sql(
"select collect_list(score), name from test_score group by name"
).show()
+--------------------+----+
| collect_list(score)|name|
+--------------------+----+
| [30, 40, 80]|小明|
|[80, 30, 40, 80, 80]|小强|
+--------------------+----+
这样我们的所有成绩聚合后的结果就聚合在一个列表里,后面就好操作了
函数编写
还是同样的三部曲,函数书写、函数绑定、函数注册
# 编写函数
def is_pass(line):
length = len(line)
total = 0
for i in line:
if i >= 60:
total += 1
return total/length
# 函数绑定
get_pass_ratio = udf(is_pass, StringType())
# 函数注册
spark.udf.register(name="get_pass_ratio", f=get_pass_ratio)
spark.sql(
"select collect_list(score), get_pass_ratio(collect_list(score)) as `及格率`, name
from test_score group by name"
).show()
+--------------------+------------------+----+
| collect_list(score)| 及格率|name|
+--------------------+------------------+----+
| [30, 40, 80]|0.3333333333333333|小明|
|[80, 30, 40, 80, 80]| 0.6|小强|
+--------------------+------------------+----+