百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 热门文章 > 正文

不要再做一个SQL Boy了,看看Hive UDAF怎么写

bigegpt 2024-09-20 14:03 4 浏览

经验丰富的后端往往是个优秀的CRUD工程师,而经验丰富的大数据工程师往往只是个SQL boy。

众所周知 Spark SQL 比较简单,学会了 SELECT SUM/COUNT(...) FROM table WHERE ... GROUP BY ...,外加上 if、case when、get_json_object、rank over (partition by ... order by ...)、lateral view explode 等函数,基本上能搞定绝大多数需求了。如果需要构造复杂对象,可以通过 array、map、named_struct、collect_list(set) 完成。Spark SQL 支持的完整的函数列表,可以在Google里搜索“spark builtin functions”找到。

如果只是做一个SQL Boy,那么了解上面这些使用方法,会Google就完事了。如果你想继续探索下技术,不妨了解下大数据框架的原理、阅读下源码、做一些实验验证自己的理解,可以参考前面几篇Spark相关的文章。如果是为了应对更为复杂的业务需求,但不脱离SQL框架,不妨了解一下 Hive User Defined Function 这一强大的工具。

User Defined Functions 简称 UDF,我们可以使用 Java 或 Python 实现自己的数据处理逻辑。Spark 和 Hive 各自有一套自己的 UDF 定义方式,但 Hive UDF 历史悠久,Spark 引擎也对其做了兼容,遇到问题网上教程也比较多,所以我们只介绍 Hive UDF。

Hive UDF 分类三类:

类型

特点

样例

UDF

输入一行,输出一行

+-*/, if, case when等

UDAF

输入多行,输出一行

count, sum, max, min, collect_list 等,与 group by 配合使用

UDTF

输入一行,输出多行

explode,可以与lateral view 配合使用

UDF 写起来比较简单,网上随便找个教程就可以。这里我们从一个例子开始,看一下 UDAF 怎么写。

背景

平时我们在刷抖音头条或者微信时,会在信息流里刷到视频或文章。app或web上一般会有很多埋点,用来将用户的行为收集到服务器,用于模型训练优化产品体验。以头条为例,当我们看到一个文章出现在信息流里时,埋点会触发一个曝光事件;当我们点开一篇文章时,会触发一个点击事件;如果我们点击的恰好是个视频,播放相关的事件也会通过埋点发送到服务器。服务器收到这些事件后,通常会发送到Kafka等消息队列服务,通过流式作业进行一系列的 Extract-Transform-Load ,其中包括数据清洗、添加反作弊标签、实时聚合等等。用户粒度数据落到 Hive 等离线存储用于离线分析,聚合以后的数据落到 ElasticSearch/MySQL 用于服务线上业务。

在进行离线分析时,我们通常会有一些初步聚合的中间表,通常包含我们需要的维度和指标。为了保持扩展性,我们将所有的指标放到同一个字段里。

比如下面这张Hive表,字段metric_data 包含了用户的曝光、点击、视频观看时长等多个指标,结构是 MAP<string, bigint>。

CREATE TABLE ad_stats_t (
  `uid` STRING, 
  `metric_data` MAP<STRING,BIGINT>
) using PARQUET

这张表是用户维度聚合的结果,如果我们想要获取大盘指标,则需要再做一次聚合:

select
  sum(metric_data['show_cnt']) as show_cnt,
  sum(metric_data['click_cnt']) as click_cnt
from ad_stats_t

如果我们想要更多指标,比如分享(share_cnt)、收藏(favorite_cnt) 等等,都需要更新上面的逻辑,更通用的方法是实现一个支持 map 聚合的UDAF,我们暂且叫它 sum_string_long_map。这个SQL只需要这么写,不管有多少指标,都可以动态支持:

select
  sum_string_long_map(metric_data) as metric_data
from ad_stats_t

如何实现

sum_string_long_map 和 sum 的逻辑框架完全一致,只是入参和出参不同,所以我们找找 Hive sum 是如何实现的。经过一番查找,在 Github上找到了名为 apache/hive 的仓库,git clone下来以后,文件 ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java 是 sum 函数的实现。

一个常规的 GenericUDAF 包含两个核心组件:Resolver 类 和 Evaluator 类。Resolver 类用于校验参数类型,然后将具体的业务逻辑代理给 Evaluator类是实现。

对于 SUM 函数而言,GenericUDAFSum 是Resolver 类,它继承了 AbstractGenericUDAFResolver,并且实现了 getEvaluator 方法;它有三个Evaluator类,都是GenericUDAFEvaluator的子类:

  • GenericUDAFSumHiveDecimal
  • GenericUDAFSumDouble
  • GenericUDAFSumLong

分别针对 Hive decical、浮点数和整数实现了SUM的逻辑。

成员方法 getEvaluator 用于解析接收参数的类型,然后将请求代理到 evaluator 类,我们先看这个方法的逻辑。

Hive 引擎在执行时,会将参数类型列表传给 getEvaluator 方法。由于 SUM 方法只接收一个参数,所以对于其他情况,我们可以抛出一个异常。

@Override
public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
    throws SemanticException {
  if (parameters.length != 1) {
    throw new UDFArgumentTypeException(parameters.length - 1,
        "Exactly one argument is expected.");
  }

  if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
    throw new UDFArgumentTypeException(0,
        "Only primitive type arguments are accepted but "
            + parameters[0].getTypeName() + " is passed.");
  }
  switch (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()) {
  case BYTE:
  case SHORT:
  case INT:
  case LONG:
    return new GenericUDAFSumLong();
  case TIMESTAMP:
  case FLOAT:
  case DOUBLE:
  case STRING:
  case VARCHAR:
  case CHAR:
    return new GenericUDAFSumDouble();
  case DECIMAL:
    return new GenericUDAFSumHiveDecimal();
  case BOOLEAN:
  case DATE:
  default:
    throw new UDFArgumentTypeException(0,
        "Only numeric or string type arguments are accepted but "
            + parameters[0].getTypeName() + " is passed.");
  }
}

SUM 参数的类型必须是能够做加法的数值型,所以必须是基本类型 ObjectInspector.Category.PRIMITIVE。当然,除了基本类型,Hive还支持 LIST, MAP, STRUCT, UNION。

基本类型是一大类,它包括 byte /short /int /long /float /double 等等,我们通过 switch-case 返回对应的 Evaluator。

关于 Evaluator

GenericUDAFSumLong 实现了对整型的求和,展开这个类的结构。它有一个内部类 SumLongAgg,用来存储聚合结果。另外还有一组成员函数,均Override了父类里的定义。

我们按照调用时间的先后顺序,看一下这几个方法的作用:

方法

功能

init

Hive会调用该方法,用来创建一个UDAF evaluator实例

getNewAggregationBuffer

返回一个对象,用来存储临时的聚合结果。本case里是一个SumLongAgg对象

iterate

接收每一行数据,并更新到aggregationBuffer里

terminatePartial

返回一批数据的聚合结果,必须以能够持久化的类型存储。这些类型包括Java 基本类型(和包装类)、Array、Map、Hadoop Writables。注意:不能使用自定义的类型。

merge

将 terminatePartial 的返回结果合并到当前的聚合结果里

terminate

将最终的聚合结果返回

我们用 Spark Driver/Executor 的视角去看这些执行步骤的话。

在 Executor 端:

  • getNewAggregationBuffer 初始化存储聚合结果的对象,比如将agg.sum 初始化为0
@Override
public AggregationBuffer getNewAggregationBuffer() throws HiveException {
  SumLongAgg result = new SumLongAgg();
  reset(result);
  return result;
}

@Override
public void reset(AggregationBuffer agg) throws HiveException {
  SumLongAgg myagg = (SumLongAgg) agg;
  myagg.empty = true;
  myagg.sum = 0;
}
  • iterate 方法接收一行数据,并合并到 agg
// 抽取以后的逻辑:
assert (parameters.length == 1);
Object partial = parameters[0]
if (partial != null) {
  SumLongAgg myagg = (SumLongAgg) agg;
  myagg.sum += PrimitiveObjectInspectorUtils.getLong(partial, inputOI);
  myagg.empty = false;
}
  • terminatePartial:数据累积到一定的行数,或者处理完一个文件,或者一个spark 分区时,会触发 terminatePartial。terminatePartial 的返回值必须是可序列化的,有可能通过shuffle发送给其他 executor或driver
// 抽取以后的逻辑:
SumLongAgg myagg = (SumLongAgg) agg;
if (myagg.empty) {
  return null;
}
result.set(myagg.sum);
return result;

在 Executor端执行 MapReduce里的 Map阶段,我们也可以称之为 PARTIAL1。在 Evaluator 父类里,总共定义了四个模式,分别是:

public static enum Mode {
  /**
   * PARTIAL1: from original data to partial aggregation data: iterate() and
   * terminatePartial() will be called.
   */
  PARTIAL1,
      /**
   * PARTIAL2: from partial aggregation data to partial aggregation data:
   * merge() and terminatePartial() will be called.
   */
  PARTIAL2,
      /**
   * FINAL: from partial aggregation to full aggregation: merge() and
   * terminate() will be called.
   */
  FINAL,
      /**
   * COMPLETE: from original data directly to full aggregation: iterate() and
   * terminate() will be called.
   */
  COMPLETE
};

PARTIAL2 对应的仍然是 MapReduce 的 MAP 阶段,在 executor 端执行,函数的调用顺序是:

  • merge: 将多个 terminatePartial 的返回结果聚合到一块。对于 SUM 而言,merge 和 iterate 的逻辑是一模一样的。
  • terminatePartial:将 partial 的结果合并到一起,但仍然不是全部的结果

FINAL 对应的是 MapReduce 的 REDUCE 阶段,调用顺序是:

  • merge:将多个 terminatePartial 的返回结果聚合到一块
  • terminate:返回最终聚合结果

对于 SUM 函数而言,由于入参和出参的类型是一样的,所以iterate 和 merge 的逻辑也是一样,terminatePartial 和 terminate 的逻辑也是一样的。

如果 terminatePartial 、terminate 的返回结果和入参的类型不一致,那么需要在 init 里根据 Mode m 判断是 PARTIAL1、PARTIAL2 或 FINAL/COMPLETE 分别设定对应的返回值类型,以支持正常的序列化和反序列化。有兴趣的朋友可以看一个更复杂的例子 GenericUDAFHistogramNumeric.java

实现 sum_string_long_map

前面提到过,sum_string_long_map 和 sum 的唯一区别是入参和出参的类型。我们先定义类的描述:

@Description(
  name     = "sum_string_long_map",
  value    = "_FUNC_( value) : Sum the long value by key in the map.",
  extended = "Example:\n" +
    "SELECT _FUNC_(metric_data) as metric_data FROM events;\n" +
    "(returns {\"a\":2, \"b\":1} if the value is [{\"a\":1, \"b\":1}, {\"a\":1}])\n\n"
)

定义 Resolver、Evaluator 的框架:

public class SumStringLongMapUDAF extends AbstractGenericUDAFResolver {
  static final Log LOG = LogFactory.getLog(SumStringLongMapUDAF.class.getName());

  @Override
  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
    // type check goes here
    if (parameters.length != 1) {
      throw new UDFArgumentTypeException(parameters.length - 1, "Exactly one argument is expected.");
    }

    ObjectInspector oi = TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(parameters[0]);
    if (oi.getCategory() != ObjectInspector.Category.MAP) {
      throw new UDFArgumentTypeException(0,
              "Argument must be MAP, but "
                      + oi.getCategory().name()
                      + " was passed.");
    }

    MapObjectInspector inputOI = (MapObjectInspector) oi;
    ObjectInspector keyOI = inputOI.getMapKeyObjectInspector();
    ObjectInspector valueOI = inputOI.getMapValueObjectInspector();

    if (keyOI.getCategory() != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentTypeException(0,
              "Map key must be PRIMITIVE, but "
                      + keyOI.getCategory().name()
                      + " was passed.");
    }
    PrimitiveObjectInspector inputKeyOI = (PrimitiveObjectInspector) keyOI;
    if (inputKeyOI.getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING) {
      throw new UDFArgumentTypeException(0,
              "Map value must be STRING, but "
                      + inputKeyOI.getPrimitiveCategory().name()
                      + " was passed.");
    }

    if (valueOI.getCategory() != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentTypeException(0,
              "Map value must be PRIMITIVE, but "
                      + valueOI.getCategory().name()
                      + " was passed.");
    }
    PrimitiveObjectInspector inputValueOI = (PrimitiveObjectInspector) valueOI;

    if (inputValueOI.getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.LONG) {
      throw new UDFArgumentTypeException(0,
              "Map value must be LONG (BIGINT), but "
                      + inputValueOI.getPrimitiveCategory().name()
                      + " was passed.");
    }

    return new SumStringLongMapEvaluator();
  }

  public static class SumStringLongMapEvaluator extends GenericUDAFEvaluator {
    MapObjectInspector inputOI;
    ObjectInspector keyOI;
    ObjectInspector valueOI;
    MapObjectInspector outputOI;
    ...

可以看到 getEvaluator 方法里包含了大量的类型处理逻辑,类型校验包括:

  1. 接收参数必须是一个
  2. 接收参数类型必须是MAP
  3. 接收参数类型MAP的 key 必须是 String
  4. 接收参数类型MAP的 value 必须是Long (BIGINT)

terminatePartial 和 terminate 的返回结果都是 Map<String, Long>,所以 aggregationBuffer 的类型是:

static class SumMapAgg extends GenericUDAFEvaluator.AbstractAggregationBuffer { 
  HashMap<String, Long> resultMap;
}

在 iterate 或 merge 时,通过遍历所有的 key,对aggregationBuffer 对应key的value 进行更新:

public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
    if (parameters == null) {
        return;
    }
    assert (parameters.length == 1);

    SumMapAgg myagg = (SumMapAgg) agg;
    if (myagg == null) {
        return;
    }

    HashMap<String, Long> partialMap = (HashMap<String, Long>)inputOI.getMap(parameters[0]);
    partialMap.forEach((key,val)-> {
        Long baseValue = myagg.resultMap.getOrDefault(key, 0L);
        Long partialValue = partialMap.getOrDefault(key, 0L);
        myagg.resultMap.put(key, baseValue + partialValue);
    });
}

成品代码可以在 Github 的 oscarzhao/hive-udf repo里查看。

Spark SQL 里注册&调用该函数

在之前的文章Spark源码阅读:SparkSession类之spark对象的使用里,我们提到了如何在本地编译和运行 Spark,这里不做过多描述。

编译UDF的话,需要先下载 Github 的 oscarzhao/hive-udf repo。进入该目录:

# 依赖 JAVA 1.8
./build.sh
cd output
pwd

进入Spark所在的目录,启动 spark-shell 时,带上这个 jar 包:

./bin/spark-shell \
  --jars <code-folder>/hive-udf-1.0.0.jar

Spark-shell 启动成功后,可以在 WebUI 的 Environment Tab看到我们加载的 jar 包:


构造Hive表,并写入三条数据:

spark.sql("create table ad_stats_t (uid STRING, `metric_data` MAP<STRING,BIGINT>) using PARQUET")
spark.sql("""
insert overwrite table ad_stats_t
select 1 as uid, map('show_cnt', 10, 'click_cnt', 1) as metric_data
UNION ALL
select 2 as uid, map('show_cnt', 2)  as metric_data
UNION ALL
select 3 as uid, map('click_cnt', 1)  as metric_data
""")

注册 UDF:

spark.sql("create temporary function sum_string_long_map as 'com.demo.map_util.SumStringLongMapUDAF'")

我们通过一个 SQL 验证 UDF 逻辑是否正确:

val df = spark.sql("""select
  sum(metric_data['show_cnt']) as show_cnt,
  sum(metric_data['click_cnt']) as click_cnt,
  sum_string_long_map(metric_data) as metric_data
from ad_stats_t""")
df.collect

值得注意的是,我们在本地只构造了三条数据,可以快速验证程序的基本逻辑是否正常运行。但数据量太少,可能会触发 COMPLETE 模式,只触发 iterate 和 terminate 方法,并不能保证 merge 和 terminatePartial 的正确性。我们仍然需要一个有多个文件的大数据集合验证逻辑的正确性。

另外,对于这个case,我们没有支持 bigint 以外的其他类型,如果要支持的话,顺着 sum 的实现比葫芦画瓢抄就行了。

关于 UDAF,我从三个地方获取了关键的信息:

  1. Google 搜索 "GenericUDAFCaseStudy",第一个结果就是Apache Hive 的 wiki
  2. Github apache/hive 下有大量UDAF的代码实现,都是经过验证和peer-review的优质代码
  3. Google 搜索 "A Complete Guide to Writing Hive UDF" 关于 dataiku 的结果

相关推荐

pyproject.toml到底是什么东西?(py trim)

最近,在Twitter上有一个Python项目的维护者,他的项目因为构建失败而出现了一些bug(这个特别的项目不提供wheel,只提供sdist)。最终,发现这个bug是由于这个项目使用了一个pypr...

BDP服务平台SDK for Python3发布(bdp数据平台)

下载地址https://github.com/imysm/opends-sdk-python3.git说明最近在开发和bdp平台有关的项目,用到了bdp的python的sdk,但是官方是基于p...

Python-for-Android (p4a):(python-for-android p4a windows)

一、Python-for-Android(p4a)简介Python-for-Android(p4a),一个强大的开发工具,能够将你的Python应用程序打包成可在Android设备上运行...

Qt for Python—Qt Designer 概览

前言本系列第三篇文章(QtforPython学习笔记—应用程序初探)、第四篇文章(QtforPython学习笔记—应用程序再探)中均是使用纯代码方式来开发PySide6GUI应用程序...

Python:判断质数(jmu-python-判断质数)

#Python:判断质数defisPrime(n):foriinrange(2,n):ifn%i==0:return0re...

为什么那么多人讨厌Python(为什么python这么难)

Python那么棒,为什么那么多人讨厌它呢?我整理了一下,主要有这些原因:用缩进替代大括号许多人抱怨Python完全依赖于缩进来创建代码块,代码多一点就很难看到函数在哪里结束,那么你就需要把一个函数拆...

一文了解 Python 中带有 else 的循环语句 for-else/while-else

在本文中,我们将向您介绍如何在python中使用带有else的for/while循环语句。可能许多人对循环和else一起使用感到困惑,因为在if-else选择结构中else正常...

python的numpy向量化语句为什么会比for快?

我们先来看看,python之类语言的for循环,和其它语言相比,额外付出了什么。我们知道,python是解释执行的。举例来说,执行x=1234+5678,对编译型语言,是从内存读入两个shor...

开眼界!Python遍历文件可以这样做

来源:【公众号】Python技术Python对于文件夹或者文件的遍历一般有两种操作方法,一种是至二级利用其封装好的walk方法操作:import osfor root,d...

告别简单format()!Python Formatter类让你的代码更专业

Python中Formatter类是string模块中的一个重要类,它实现了Python字符串格式化的底层机制,允许开发者创建自定义的格式化行为。通过深入理解Formatter类的工作原理和使用方法,...

python学习——038如何将for循环改写成列表推导式

在Python里,列表推导式是一种能够简洁生成列表的表达式,可用于替换普通的for循环。下面是列表推导式的基本语法和常见应用场景。基本语法result=[]foriteminite...

详谈for循环和while循环的区别(for循环语句与while循环语句有什么区别)

初九,潜龙勿用在刚开始使用python循环语句时,经常会遇到for循环和while循环的混用,不清楚该如何选择;今天就对这2个循环语句做深入的分析,让大家更好地了解这2个循环语句以方便后续学习的加深。...

Python编程基础:循环结构for和while

Python中的循环结构包括两个,一是遍历循环(for循环),一是条件循环(while循环)。遍历循环遍历循环(for循环)会挨个访问序列或可迭代对象的元素,并执行里面的代码块。foriinra...

学习编程第154天 python编程 for循环输出菱形图

今天学习的是刘金玉老师零基础Python教程第38期,主要内容是python编程for循环输出菱形※。(一)利用for循环输出菱形形状的*号图形1.思路:将菱形分解为上下两个部分三角形图案,分别利用...

python 10个堪称完美的for循环实践

在Python中,for循环的高效使用能显著提升代码性能和可读性。以下是10个堪称完美的for循环实践,涵盖数据处理、算法优化和Pythonic编程风格:1.遍历列表同时获取索引(enumerate...