经验丰富的后端往往是个优秀的CRUD工程师,而经验丰富的大数据工程师往往只是个SQL boy。
众所周知 Spark SQL 比较简单,学会了 SELECT SUM/COUNT(...) FROM table WHERE ... GROUP BY ...,外加上 if、case when、get_json_object、rank over (partition by ... order by ...)、lateral view explode 等函数,基本上能搞定绝大多数需求了。如果需要构造复杂对象,可以通过 array、map、named_struct、collect_list(set) 完成。Spark SQL 支持的完整的函数列表,可以在Google里搜索“spark builtin functions”找到。
如果只是做一个SQL Boy,那么了解上面这些使用方法,会Google就完事了。如果你想继续探索下技术,不妨了解下大数据框架的原理、阅读下源码、做一些实验验证自己的理解,可以参考前面几篇Spark相关的文章。如果是为了应对更为复杂的业务需求,但不脱离SQL框架,不妨了解一下 Hive User Defined Function 这一强大的工具。
User Defined Functions 简称 UDF,我们可以使用 Java 或 Python 实现自己的数据处理逻辑。Spark 和 Hive 各自有一套自己的 UDF 定义方式,但 Hive UDF 历史悠久,Spark 引擎也对其做了兼容,遇到问题网上教程也比较多,所以我们只介绍 Hive UDF。
Hive UDF 分类三类:
类型 | 特点 | 样例 |
UDF | 输入一行,输出一行 | +-*/, if, case when等 |
UDAF | 输入多行,输出一行 | count, sum, max, min, collect_list 等,与 group by 配合使用 |
UDTF | 输入一行,输出多行 | explode,可以与lateral view 配合使用 |
UDF 写起来比较简单,网上随便找个教程就可以。这里我们从一个例子开始,看一下 UDAF 怎么写。
背景
平时我们在刷抖音头条或者微信时,会在信息流里刷到视频或文章。app或web上一般会有很多埋点,用来将用户的行为收集到服务器,用于模型训练优化产品体验。以头条为例,当我们看到一个文章出现在信息流里时,埋点会触发一个曝光事件;当我们点开一篇文章时,会触发一个点击事件;如果我们点击的恰好是个视频,播放相关的事件也会通过埋点发送到服务器。服务器收到这些事件后,通常会发送到Kafka等消息队列服务,通过流式作业进行一系列的 Extract-Transform-Load ,其中包括数据清洗、添加反作弊标签、实时聚合等等。用户粒度数据落到 Hive 等离线存储用于离线分析,聚合以后的数据落到 ElasticSearch/MySQL 用于服务线上业务。
在进行离线分析时,我们通常会有一些初步聚合的中间表,通常包含我们需要的维度和指标。为了保持扩展性,我们将所有的指标放到同一个字段里。
比如下面这张Hive表,字段metric_data 包含了用户的曝光、点击、视频观看时长等多个指标,结构是 MAP<string, bigint>。
CREATE TABLE ad_stats_t (
`uid` STRING,
`metric_data` MAP<STRING,BIGINT>
) using PARQUET
这张表是用户维度聚合的结果,如果我们想要获取大盘指标,则需要再做一次聚合:
select
sum(metric_data['show_cnt']) as show_cnt,
sum(metric_data['click_cnt']) as click_cnt
from ad_stats_t
如果我们想要更多指标,比如分享(share_cnt)、收藏(favorite_cnt) 等等,都需要更新上面的逻辑,更通用的方法是实现一个支持 map 聚合的UDAF,我们暂且叫它 sum_string_long_map。这个SQL只需要这么写,不管有多少指标,都可以动态支持:
select
sum_string_long_map(metric_data) as metric_data
from ad_stats_t
如何实现
sum_string_long_map 和 sum 的逻辑框架完全一致,只是入参和出参不同,所以我们找找 Hive sum 是如何实现的。经过一番查找,在 Github上找到了名为 apache/hive 的仓库,git clone下来以后,文件 ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java 是 sum 函数的实现。
一个常规的 GenericUDAF 包含两个核心组件:Resolver 类 和 Evaluator 类。Resolver 类用于校验参数类型,然后将具体的业务逻辑代理给 Evaluator类是实现。
对于 SUM 函数而言,GenericUDAFSum 是Resolver 类,它继承了 AbstractGenericUDAFResolver,并且实现了 getEvaluator 方法;它有三个Evaluator类,都是GenericUDAFEvaluator的子类:
- GenericUDAFSumHiveDecimal
- GenericUDAFSumDouble
- GenericUDAFSumLong
分别针对 Hive decical、浮点数和整数实现了SUM的逻辑。
成员方法 getEvaluator 用于解析接收参数的类型,然后将请求代理到 evaluator 类,我们先看这个方法的逻辑。
Hive 引擎在执行时,会将参数类型列表传给 getEvaluator 方法。由于 SUM 方法只接收一个参数,所以对于其他情况,我们可以抛出一个异常。
@Override
public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
throws SemanticException {
if (parameters.length != 1) {
throw new UDFArgumentTypeException(parameters.length - 1,
"Exactly one argument is expected.");
}
if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
throw new UDFArgumentTypeException(0,
"Only primitive type arguments are accepted but "
+ parameters[0].getTypeName() + " is passed.");
}
switch (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()) {
case BYTE:
case SHORT:
case INT:
case LONG:
return new GenericUDAFSumLong();
case TIMESTAMP:
case FLOAT:
case DOUBLE:
case STRING:
case VARCHAR:
case CHAR:
return new GenericUDAFSumDouble();
case DECIMAL:
return new GenericUDAFSumHiveDecimal();
case BOOLEAN:
case DATE:
default:
throw new UDFArgumentTypeException(0,
"Only numeric or string type arguments are accepted but "
+ parameters[0].getTypeName() + " is passed.");
}
}
SUM 参数的类型必须是能够做加法的数值型,所以必须是基本类型 ObjectInspector.Category.PRIMITIVE。当然,除了基本类型,Hive还支持 LIST, MAP, STRUCT, UNION。
基本类型是一大类,它包括 byte /short /int /long /float /double 等等,我们通过 switch-case 返回对应的 Evaluator。
关于 Evaluator
GenericUDAFSumLong 实现了对整型的求和,展开这个类的结构。它有一个内部类 SumLongAgg,用来存储聚合结果。另外还有一组成员函数,均Override了父类里的定义。
我们按照调用时间的先后顺序,看一下这几个方法的作用:
方法 | 功能 |
init | Hive会调用该方法,用来创建一个UDAF evaluator实例 |
getNewAggregationBuffer | 返回一个对象,用来存储临时的聚合结果。本case里是一个SumLongAgg对象 |
iterate | 接收每一行数据,并更新到aggregationBuffer里 |
terminatePartial | 返回一批数据的聚合结果,必须以能够持久化的类型存储。这些类型包括Java 基本类型(和包装类)、Array、Map、Hadoop Writables。注意:不能使用自定义的类型。 |
merge | 将 terminatePartial 的返回结果合并到当前的聚合结果里 |
terminate | 将最终的聚合结果返回 |
我们用 Spark Driver/Executor 的视角去看这些执行步骤的话。
在 Executor 端:
- getNewAggregationBuffer 初始化存储聚合结果的对象,比如将agg.sum 初始化为0
@Override
public AggregationBuffer getNewAggregationBuffer() throws HiveException {
SumLongAgg result = new SumLongAgg();
reset(result);
return result;
}
@Override
public void reset(AggregationBuffer agg) throws HiveException {
SumLongAgg myagg = (SumLongAgg) agg;
myagg.empty = true;
myagg.sum = 0;
}
- iterate 方法接收一行数据,并合并到 agg
// 抽取以后的逻辑:
assert (parameters.length == 1);
Object partial = parameters[0]
if (partial != null) {
SumLongAgg myagg = (SumLongAgg) agg;
myagg.sum += PrimitiveObjectInspectorUtils.getLong(partial, inputOI);
myagg.empty = false;
}
- terminatePartial:数据累积到一定的行数,或者处理完一个文件,或者一个spark 分区时,会触发 terminatePartial。terminatePartial 的返回值必须是可序列化的,有可能通过shuffle发送给其他 executor或driver
// 抽取以后的逻辑:
SumLongAgg myagg = (SumLongAgg) agg;
if (myagg.empty) {
return null;
}
result.set(myagg.sum);
return result;
在 Executor端执行 MapReduce里的 Map阶段,我们也可以称之为 PARTIAL1。在 Evaluator 父类里,总共定义了四个模式,分别是:
public static enum Mode {
/**
* PARTIAL1: from original data to partial aggregation data: iterate() and
* terminatePartial() will be called.
*/
PARTIAL1,
/**
* PARTIAL2: from partial aggregation data to partial aggregation data:
* merge() and terminatePartial() will be called.
*/
PARTIAL2,
/**
* FINAL: from partial aggregation to full aggregation: merge() and
* terminate() will be called.
*/
FINAL,
/**
* COMPLETE: from original data directly to full aggregation: iterate() and
* terminate() will be called.
*/
COMPLETE
};
PARTIAL2 对应的仍然是 MapReduce 的 MAP 阶段,在 executor 端执行,函数的调用顺序是:
- merge: 将多个 terminatePartial 的返回结果聚合到一块。对于 SUM 而言,merge 和 iterate 的逻辑是一模一样的。
- terminatePartial:将 partial 的结果合并到一起,但仍然不是全部的结果
FINAL 对应的是 MapReduce 的 REDUCE 阶段,调用顺序是:
- merge:将多个 terminatePartial 的返回结果聚合到一块
- terminate:返回最终聚合结果
对于 SUM 函数而言,由于入参和出参的类型是一样的,所以iterate 和 merge 的逻辑也是一样,terminatePartial 和 terminate 的逻辑也是一样的。
如果 terminatePartial 、terminate 的返回结果和入参的类型不一致,那么需要在 init 里根据 Mode m 判断是 PARTIAL1、PARTIAL2 或 FINAL/COMPLETE 分别设定对应的返回值类型,以支持正常的序列化和反序列化。有兴趣的朋友可以看一个更复杂的例子 GenericUDAFHistogramNumeric.java
实现 sum_string_long_map
前面提到过,sum_string_long_map 和 sum 的唯一区别是入参和出参的类型。我们先定义类的描述:
@Description(
name = "sum_string_long_map",
value = "_FUNC_( value) : Sum the long value by key in the map.",
extended = "Example:\n" +
"SELECT _FUNC_(metric_data) as metric_data FROM events;\n" +
"(returns {\"a\":2, \"b\":1} if the value is [{\"a\":1, \"b\":1}, {\"a\":1}])\n\n"
)
定义 Resolver、Evaluator 的框架:
public class SumStringLongMapUDAF extends AbstractGenericUDAFResolver {
static final Log LOG = LogFactory.getLog(SumStringLongMapUDAF.class.getName());
@Override
public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
// type check goes here
if (parameters.length != 1) {
throw new UDFArgumentTypeException(parameters.length - 1, "Exactly one argument is expected.");
}
ObjectInspector oi = TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(parameters[0]);
if (oi.getCategory() != ObjectInspector.Category.MAP) {
throw new UDFArgumentTypeException(0,
"Argument must be MAP, but "
+ oi.getCategory().name()
+ " was passed.");
}
MapObjectInspector inputOI = (MapObjectInspector) oi;
ObjectInspector keyOI = inputOI.getMapKeyObjectInspector();
ObjectInspector valueOI = inputOI.getMapValueObjectInspector();
if (keyOI.getCategory() != ObjectInspector.Category.PRIMITIVE) {
throw new UDFArgumentTypeException(0,
"Map key must be PRIMITIVE, but "
+ keyOI.getCategory().name()
+ " was passed.");
}
PrimitiveObjectInspector inputKeyOI = (PrimitiveObjectInspector) keyOI;
if (inputKeyOI.getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING) {
throw new UDFArgumentTypeException(0,
"Map value must be STRING, but "
+ inputKeyOI.getPrimitiveCategory().name()
+ " was passed.");
}
if (valueOI.getCategory() != ObjectInspector.Category.PRIMITIVE) {
throw new UDFArgumentTypeException(0,
"Map value must be PRIMITIVE, but "
+ valueOI.getCategory().name()
+ " was passed.");
}
PrimitiveObjectInspector inputValueOI = (PrimitiveObjectInspector) valueOI;
if (inputValueOI.getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.LONG) {
throw new UDFArgumentTypeException(0,
"Map value must be LONG (BIGINT), but "
+ inputValueOI.getPrimitiveCategory().name()
+ " was passed.");
}
return new SumStringLongMapEvaluator();
}
public static class SumStringLongMapEvaluator extends GenericUDAFEvaluator {
MapObjectInspector inputOI;
ObjectInspector keyOI;
ObjectInspector valueOI;
MapObjectInspector outputOI;
...
可以看到 getEvaluator 方法里包含了大量的类型处理逻辑,类型校验包括:
- 接收参数必须是一个
- 接收参数类型必须是MAP
- 接收参数类型MAP的 key 必须是 String
- 接收参数类型MAP的 value 必须是Long (BIGINT)
terminatePartial 和 terminate 的返回结果都是 Map<String, Long>,所以 aggregationBuffer 的类型是:
static class SumMapAgg extends GenericUDAFEvaluator.AbstractAggregationBuffer {
HashMap<String, Long> resultMap;
}
在 iterate 或 merge 时,通过遍历所有的 key,对aggregationBuffer 对应key的value 进行更新:
public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
if (parameters == null) {
return;
}
assert (parameters.length == 1);
SumMapAgg myagg = (SumMapAgg) agg;
if (myagg == null) {
return;
}
HashMap<String, Long> partialMap = (HashMap<String, Long>)inputOI.getMap(parameters[0]);
partialMap.forEach((key,val)-> {
Long baseValue = myagg.resultMap.getOrDefault(key, 0L);
Long partialValue = partialMap.getOrDefault(key, 0L);
myagg.resultMap.put(key, baseValue + partialValue);
});
}
成品代码可以在 Github 的 oscarzhao/hive-udf repo里查看。
Spark SQL 里注册&调用该函数
在之前的文章Spark源码阅读:SparkSession类之spark对象的使用里,我们提到了如何在本地编译和运行 Spark,这里不做过多描述。
编译UDF的话,需要先下载 Github 的 oscarzhao/hive-udf repo。进入该目录:
# 依赖 JAVA 1.8
./build.sh
cd output
pwd
进入Spark所在的目录,启动 spark-shell 时,带上这个 jar 包:
./bin/spark-shell \
--jars <code-folder>/hive-udf-1.0.0.jar
Spark-shell 启动成功后,可以在 WebUI 的 Environment Tab看到我们加载的 jar 包:
构造Hive表,并写入三条数据:
spark.sql("create table ad_stats_t (uid STRING, `metric_data` MAP<STRING,BIGINT>) using PARQUET")
spark.sql("""
insert overwrite table ad_stats_t
select 1 as uid, map('show_cnt', 10, 'click_cnt', 1) as metric_data
UNION ALL
select 2 as uid, map('show_cnt', 2) as metric_data
UNION ALL
select 3 as uid, map('click_cnt', 1) as metric_data
""")
注册 UDF:
spark.sql("create temporary function sum_string_long_map as 'com.demo.map_util.SumStringLongMapUDAF'")
我们通过一个 SQL 验证 UDF 逻辑是否正确:
val df = spark.sql("""select
sum(metric_data['show_cnt']) as show_cnt,
sum(metric_data['click_cnt']) as click_cnt,
sum_string_long_map(metric_data) as metric_data
from ad_stats_t""")
df.collect
值得注意的是,我们在本地只构造了三条数据,可以快速验证程序的基本逻辑是否正常运行。但数据量太少,可能会触发 COMPLETE 模式,只触发 iterate 和 terminate 方法,并不能保证 merge 和 terminatePartial 的正确性。我们仍然需要一个有多个文件的大数据集合验证逻辑的正确性。
另外,对于这个case,我们没有支持 bigint 以外的其他类型,如果要支持的话,顺着 sum 的实现比葫芦画瓢抄就行了。
关于 UDAF,我从三个地方获取了关键的信息:
- Google 搜索 "GenericUDAFCaseStudy",第一个结果就是Apache Hive 的 wiki
- Github apache/hive 下有大量UDAF的代码实现,都是经过验证和peer-review的优质代码
- Google 搜索 "A Complete Guide to Writing Hive UDF" 关于 dataiku 的结果