在Spark中,也支持Hive中的自定义函数。自定义函数大致可以分为三种:
- UDF(User-Defined-Function),即最基本的自定义函数,类似to_char,to_date等
- UDAF(User- Defined Aggregation Funcation),用户自定义聚合函数,类似在group by之后使用的sum,avg
- UDTF(User-Defined Table-Generating Functions),用户自定义生成函数,有点像stream里面的flatMap
一、自定义UDF 拼接三个参数,
1.1继承org.apache.spark.sql.api.java.UDFxx(1-22);
1.2、实现call方法
@Override
public String call(Long v1, String v2, String split) throws Exception {
return String.valueOf(v1) + split + v2;
}
完整代码实现
package com.chb.shopanalysis.hive.UDF;
import org.apache.spark.sql.api.java.UDF3;
/**
* 自定义UDF
* 1 上海 split
* 拼接成"1:上海"
* 将两个字段拼接起来(使用指定的分隔符)
* @author chb
*
*/
public class ConcatLongStringUDF implements UDF3<Long, String, String, String> {
private static final long serialVersionUID = 1L;
@Override
public String call(Long v1, String v2, String split) throws Exception {
return String.valueOf(v1) + split + v2;
}
}
1.4、注册函数
// 注册自定义函数
sqlContext.udf().register(
"concat_long_string", //自定义函数的名称
new ConcatLongStringUDF(), //自定义UDF对象
DataTypes.StringType); //返回数据类型
1.5、使用函数
/**
* 从hive表中读取数据, 使用自定义聚合函数
*/
private static void readProductClickInfo() {
// 可以获取到每个area下的每个product_id的城市信息拼接起来的串
String sql =
"SELECT city_id, city_name,"
+ "area,"
+ "product_id,"
+ "concat_long_string(city_id,city_name,':') city_infos "
+ "FROM click_product_basic ";
// 使用Spark SQL执行这条SQL语句
DataFrame df = sqlContext.sql(sql);
//展示结果
df.show();
}
二、用户自定义聚合函数UDAF
2.1、继承org.apache.spark.sql.expressions.UserDefinedAggregateFunction
2.2、定义输入,缓存,输出字段类型
// 指定输入数据的字段与类型
private StructType inputSchema = DataTypes.createStructType(Arrays.asList(
DataTypes.createStructField("cityInfo", DataTypes.StringType, true)));
// 指定缓冲数据的字段与类型
private StructType bufferSchema = DataTypes.createStructType(Arrays.asList(
DataTypes.createStructField("bufferCityInfo", DataTypes.StringType, true)));
// 指定返回类型
private DataType dataType = DataTypes.StringType;
2.3、deterministic()
决定每次相同输入,是否返回相同输出, 一般都会设置为true.
@Override
//每次相同的输入是否返回相同的输出
public boolean deterministic() {
return deterministic;
}
2.4、初始化
/**
* 初始化
* 可以认为是,你自己在内部指定一个初始的值
*/
@Override
public void initialize(MutableAggregationBuffer buffer) {
buffer.update(0, "");
}
2.5、更新, 这个是组类根据自己的逻辑进行拼接, 然后更新数据
/**
* 更新
* 可以认为是,一个一个地将组内的字段值传递进来
* 实现拼接的逻辑
*/
@Override
public void update(MutableAggregationBuffer buffer, Row input) {
// 缓冲中的已经拼接过的城市信息串
String bufferCityInfo = buffer.getString(0);
// 刚刚传递进来的某个城市信息
String cityInfo = input.getString(0);
// 在这里要实现去重的逻辑
// 判断:之前没有拼接过某个城市信息,那么这里才可以接下去拼接新的城市信息
if(!bufferCityInfo.contains(cityInfo)) {
if("".equals(bufferCityInfo)) {
bufferCityInfo += cityInfo;
} else {
// 比如1:北京
//2:上海
//结果 1:北京,2:上海
//再 来一个 1:北京 就不会拼接进去。
bufferCityInfo += "," + cityInfo;
}
buffer.update(0, bufferCityInfo);
}
}
2.6、合并, 将所有节点的数据进行合并
/**
* 合并
* update操作,可能是针对一个分组内的部分数据,在某个节点上发生的
* 但是可能一个分组内的数据,会分布在多个节点上处理
* 此时就要用merge操作,将各个节点上分布式拼接好的串,合并起来
*/
@Override
public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
String bufferCityInfo1 = buffer1.getString(0);
String bufferCityInfo2 = buffer2.getString(0);
for(String cityInfo : bufferCityInfo2.split(",")) {
if(!bufferCityInfo1.contains(cityInfo)) {
if("".equals(bufferCityInfo1)) {
bufferCityInfo1 += cityInfo;
} else {
bufferCityInfo1 += "," + cityInfo;
}
}
}
buffer1.update(0, bufferCityInfo1);
}
2.7、输出最终结果, 可能我们需要的输出格式,可以在该方法中,进行格式化。
@Override
//计算出最终结果
public Object evaluate(Row row) {
return row.getString(0);
}
2.8、注册函数
sqlContext.udf().register("group_concat_distinct",
new GroupConcatDistinctUDAF());
2.9、使用
/**
* 从hive表中读取数据, 使用自定义聚合函数
*/
private static void readProductClickInfo() {
// 按照area和product_id两个字段进行分组
// 计算出各区域各商品的点击次数
// 可以获取到每个area下的每个product_id的城市信息拼接起来的串
String sql = "SELECT area, product_id,"
+ "count(*) click_count, "
+ "group_concat_distinct(concat_long_string(city_id,city_name,':')) city_infos "
+ "FROM click_product_basic "
+ "GROUP BY area,product_id ";
// 使用Spark SQL执行这条SQL语句
DataFrame df = sqlContext.sql(sql);
df.show();
// 再次将查询出来的数据注册为一个临时表
// 各区域各商品的点击次数(以及额外的城市列表)
df.registerTempTable("tmp_area_product_click_count");
}