1. 前言
因為最近工作中有需要自定義udf,所以本文記錄下最近所了解到的udf的知識。主要講述hive中如何自定義udf,至于udf一些原理性的東西,比如udf在mr過程中怎么起作用的,這個涉及到hive的細節,我也不清楚,所以本文不會涉及,知道多少寫多少吧。
2. UDF分類
hive中udf主要分為三類:
- 標準UDF
這種類型的udf每次接受的輸入是一行數據中的一個列或者多個列(下面我把這個一行的一列叫著一個單元吧,類似表格中的一個單元格),然后輸出是一個單元。比如abs, array,asin這種都是標準udf。
自定義標準函數需要繼承實現抽象類org.apache.hadoop.hive.ql.udf.generic.GenericUDF
- 自定義聚合函數(UDAF)
比如max,min這種函數都是hive內置聚合函數。聚合函數和標準udf的區別是:聚合函數需要接收多行輸入才能計算出結果,比如max就需要接收表中所有數據(或者group by中分組內所有數據)才能計算出最大值。
自定義聚合函數需要實現抽象類org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver
- 自定義表生成函數(UDTF)
上面1,2中的udf都只輸出一個標量的數據(一個單元)。表生成函數故名思義,其輸出有點像子查詢,可以一次輸出多行多列。
自定義表生成函數需要實現抽象類org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
。
2. 自定義UDF
引入maven依賴
<dependency>
<groupId>org.apache.hive</groupId>
<artifactId>hive-exec</artifactId>
<version>2.3.0</version>
</dependency>
2.1 自定義標準UDF
2.1.1 實現抽象類GenericUDF
該類的全路徑為:org.apache.hadoop.hive.ql.udf.generic.GenericUDF
1. 抽象類GenericUDF解釋
GenericUDF類如下:
public abstract class GenericUDF implements Closeable {
...
/* 實例化后initialize方法只會調用一次
- 參數arguments即udf接收的參數列表對應的objectinspector
- 返回的ObjectInspector對象就是udf返回值的對應的objectinspector
initialize方法中往往做的工作是檢查一下arguments是否和你udf需要的參數個數以及類型是否匹配。
*/
public abstract ObjectInspector initialize(ObjectInspector[] arguments)
throws UDFArgumentException;
...
// 真正的udf邏輯在這里實現
// - 參數arguments即udf函數輸入數據,這個數組的長度和initialize的參數長度一樣
//
public abstract Object evaluate(DeferredObject[] arguments)
throws HiveException;
}
GenericUDF有很多的方法,但是只有上面兩個抽象方法需要自己實現。
關于ObjectInspector,HIVE在傳遞數據時會包含數據本身以及對應的ObjectInspector,ObjectInspector中包含數據類型信息,通過oi去解析獲得數據。
2.1.2 實例
假設這里要實現下面這種功能標準udf:
cycle_range(col_name, num)
它的接收一列,以及一了整數值為參數,然后將這列轉換為一個index(index 屬于[0,num))到列值的映射,像下面這樣:
> SELECT cycle_range(name, 3) FROM src_table;
INDEX NAME
{1, "eric"}
{2, "aaron"}
{0, "john"}
{1, "marry"}
{2, "hellen"}
{0, "jerry"}
{1, "ellen"}
...
這里定義一個叫cycle_range的標準udf去實現列值的轉換,實現如下:
/**
這里使用注解描述udf信息,當使用beeline命令'describe function cycle_range'時,會輸出value中的介紹信息,其中_FUNC_會被替換成真實的udf名稱。
*/
@Description(name = "cycle_range",
value = "_FUNC_(x, num) - return a map containes an index as key and x as value",
extended = "Example:\n"
+ " > SELECT _FUNC_(x, 3) FROM src;\n"
+ "{i,x}, i in [0 - 3)\n"
)
public class GenericUDFRange extends GenericUDF {
// 第二個參數是一個整型常量,放在這里
private static LongWritable rangeNum = null;
// index 遞增并對rangeNum取模的結果
private static Long index = 0L;
// udf 返回值
private transient Map<Object,Object> ret = new HashMap<Object,Object>();
// udf參數整型常量可以是BYTE/SHORT/INT/LONG 這個converter將它們都轉換成long處理
private transient ObjectInspectorConverters.Converter rangeConverter;
// 在inittialize里檢查一下參數個數與類型
public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
// 只接受兩個參數
if(arguments.length != 2){
throw new UDFArgumentException(
"RANGE() requires 2 arguments, got " + arguments.length
);
}
// 第二個參數必須是PRIMITIVE這一類的,這是hive sql內置的類型,可以對應到java的primitive type,此外還必須是BYTE/SHORT/INT/LONG之一
if(arguments[1].getCategory() != ObjectInspector.Category.PRIMITIVE){
throw new UDFArgumentException(
"RANGE() only take primitive Integer type, got " + arguments[1].getTypeName()
);
}
PrimitiveObjectInspector poi = (PrimitiveObjectInspector)arguments[1];
// 獲取到第二個參數的具體類型枚舉
PrimitiveObjectInspector.PrimitiveCategory rangeNumType = poi.getPrimitiveCategory();
ObjectInspector outputInspector = null;
switch (rangeNumType){
case BYTE:
case SHORT:
case INT:
case LONG:
// 以上4個case是合法類型,獲得一個converter將這四類都轉換成WritableLong處理
rangeConverter = ObjectInspectorConverters.getConverter(
arguments[1], PrimitiveObjectInspectorFactory.writableLongObjectInspector
);
// udf的輸出值的oi,輸出的是一個map對應的ObjectInspector, key是long,value還是原來的列的oi
outputInspector = ObjectInspectorFactory.getStandardMapObjectInspector(
PrimitiveObjectInspectorFactory.javaLongObjectInspector, arguments[0]);
return outputInspector;
default:
throw new UDFArgumentException(
"RANGE only takes BYTE/SHORT/INT/LONG types as the second arguments type, got " + arguments[1].getTypeName()
);
}
}
// 這里開始接收實際的一行一行的輸入數據,然后返回處理后的值
// deferredObjects應該包含兩個值,第一個值是列的值,第二個值是那個整型常量range值
public Object evaluate(DeferredObject[] deferredObjects) throws HiveException {
// 拿到range值
Object rangeObject = deferredObjects[1].get();
if(rangeNum == null){
rangeNum = new LongWritable();
// 用coverter都轉換成LongWritable,然后保存起來.
rangeObject = rangeConverter.convert(rangeObject);
rangeNum.set(Math.abs(((LongWritable)rangeObject).get()));
}
// 計算index 對range的模
index = (index + 1) % rangeNum.get();
// 由于ret是這個udf實例的成員,用來保存返回的map,而evaluate又會不停的調用,所以這里put前都會clear一下,保證始終只有當前處理后的返回值。
ret.clear();
// 設置返回值,返回
ret.put(index, deferredObjects[0].get());
return ret;
}
public String getDisplayString(String[] strings) {
return getStandardDisplayString("range", strings,",");
}
}
編寫好后:
- 打jar包,最好打fat jar,把依賴都打進去,假設我的jar包的路徑:"/Users/eric/udf-1.0-SNAPSHOT.jar"
- 在beeline 終端將jar加入hive的classpath:
add jar /Users/eric/udf-1.0-SNAPSHOT.jar
- 創建udf
create temporary function cycle_range as 'me.eric.udfs.GenericUDFRange'
成功后就可以使用了。
2.2 自定義聚合函數UDAF
2.2.1 實現抽象類AbstractGenericUDAFResolver
實現自定義UDAF首先要繼承并實現類AbstractGenericUDAFResolver
,有下面兩個方法:
public abstract class AbstractGenericUDAFResolver
implements GenericUDAFResolver2
{
@SuppressWarnings("deprecation")
@Override
public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info)
throws SemanticException {
if (info.isAllColumns()) {
throw new SemanticException(
"The specified syntax for UDAF invocation is invalid.");
}
return getEvaluator(info.getParameters());
}
/**
由于上面的getEvaluator也是調用的這個方法實現,所以只需要重寫著這個
getEvaluator即可。 udaf函數的主要邏輯不是getEvaluator方法里里完成的。
而是在其返回的GenericUDAFEvaluator中實現的,那么在getEvaluator方法中往往只需要根據參數info(info中保存了傳遞給udaf的實際參數信息)做一下udaf的參數類型檢查即可,
然后返回用戶自定義的GenericUDAFEvaluator。
*/
@Override
public GenericUDAFEvaluator getEvaluator(TypeInfo[] info)
throws SemanticException {
throw new SemanticException(
"This UDAF does not support the deprecated getEvaluator() method.");
}
上面介紹中說到GenericUDAFEvaluator才是真正實現udaf業務邏輯的地方,下面是GenericUDAFEvaluator抽象類的的實現:
public abstract class GenericUDAFEvaluator implements Closeable {
@Retention(RetentionPolicy.RUNTIME)
public static @interface AggregationType {
boolean estimable() default false;
}
...
public static enum Mode {
/**
* PARTIAL1: from original data to partial aggregation data: iterate() and
* terminatePartial() will be called.
*/
PARTIAL1,
/**
* PARTIAL2: from partial aggregation data to partial aggregation data:
* merge() and terminatePartial() will be called.
*/
PARTIAL2,
/**
* FINAL: from partial aggregation to full aggregation: merge() and
* terminate() will be called.
*/
FINAL,
/**
* COMPLETE: from original data directly to full aggregation: iterate() and
* terminate() will be called.
*/
COMPLETE
};
Mode mode;
public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
// This function should be overriden in every sub class
// And the sub class should call super.init(m, parameters) to get mode set.
mode = m;
return null;
}
public abstract AggregationBuffer getNewAggregationBuffer() throws HiveException;
public abstract void reset(AggregationBuffer agg) throws HiveException;
public void aggregate(AggregationBuffer agg, Object[] parameters) throws HiveException {
if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
iterate(agg, parameters);
} else {
assert (parameters.length == 1);
merge(agg, parameters[0]);
}
}
public Object evaluate(AggregationBuffer agg) throws HiveException {
if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {
return terminatePartial(agg);
} else {
return terminate(agg);
}
}
public abstract void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException;
public abstract Object terminatePartial(AggregationBuffer agg) throws HiveException;
public abstract void merge(AggregationBuffer agg, Object partial) throws HiveException;
public abstract Object terminate(AggregationBuffer agg) throws HiveException;
首先是上面枚舉類型Mode的幾個枚舉值:PARTIAL1,PARTIAL2,FINAL,COMPLETE, 同時mode也是方法init的參數。這幾個枚舉值跟聚合涉及到的過程有關系, map-reduce中聚合往往涉及到shuffle的過程,這其中又可能涉及到map端的combine,然后map到reduce過程中數據的shuffle,然后在在reduce端merge。
下面這張圖大概的描述了一下各個階段的對應關系:
這張圖中沒有包含COMPLETE,從上面代碼中COMPLETE的注釋可以看出來,COMPLETE表示直接從原始數據聚合到最終結果,也就是說不存在中間需要先在map端完成部分聚合結果,然后再到reduce端完成最終聚合一個過程,COMPLETE出現在一個完全map only的任務中,所以沒有和其他三個階段一起出現。
上圖描述了三個階段調用的方法,這也就是需要自己實現的方法:
- PARTIAL1
- iterate(AggregationBuffer agg, Object[] parameters)
AggregationBuffer是一個需要你實現的數據結構,用來臨時保存聚合的數據,parameters是傳遞給udaf的實際參數,這個方法的功能可以描述成: 拿到一條條數據記錄方法在parameters里,然后聚合到agg中,怎么聚合自己實現,比如agg就是一個數組,你把所有迭代的數據保存到數組中都可以。agg相當于返回結果, - terminatePartial(AggregationBuffer agg)
iterate迭代了map中的數據并保存到agg中,并傳遞給terminatePartial,接下來terminatePartial完成計算,terminatePartial返回Object類型結果顯然還是要傳遞給下一個階段PARTIAL2的,但是PARTIAL2怎么知道Object到底是什么?前面提到HIVE都是通過ObjectInspector來獲取數據類型信息的,但是PARTIAL2的輸入數據ObjectInspector怎么來的?顯然每個階段輸出數據對應的ObjectInspector只有你自己知道,上面代碼中還有一個init()方法是需要你實現了(init在每一個階段都會調用一次 ),init的參數m表明了當前階段(當前處于PARTIAL1),你需要在init中根據當前階段m,設置一個ObjectInspector表示當前的輸出oi就行了,init返回一個ObjectInspcetor表示當前階段的輸出數據類信息(也就是下一階段的輸入數據信息)。
- iterate(AggregationBuffer agg, Object[] parameters)
- PARTIAL2
PARTIAL2的輸入是基于PARTIAL1的輸出的,PARTIAL1輸出即terminatePartial的返回值。- merge(AggregationBuffer agg, Object partial)
agg和partial1中的一樣,既是參數,也是返回值。partial就是partial1中terminatePartial的返回值,partial的具體數據信息需要你根據ObjectInspector獲取了。merger就表示把partial值先放到agg里,待會計算。 - terminatePartial
和partial1一樣。
- merge(AggregationBuffer agg, Object partial)
- FINAL
FINAL進入到reduce階段,也就是要完成最終結果的計算,和PARTIAL2不同的是它調用terminate,沒什么好說的,輸出最終結果而已。
關于init方法,方法原型:
public ObjectInspector init(Mode m, ObjectInspector[] parameters)
這個方法會在每個階段都會調用一次,參數m表示當前調用的階段,parameters表示當前階段輸入數據的oi。前面提到partial1的terminatePartial的輸出就是partial2的輸入數據,那么此時partial1的輸出數據對應的oi,應該和partial2時調用init的參數parameters對應起來才能保存不出錯。
2.2.1 UDAF實例
這里實現的udaf的實例,他完成如下功能:
> SELECT col_concat(id, '<' , '>', ',' ) FROM person;
輸出:
<1,2,3,4,5,6>
udaf實現將某一個使用特定符號連接起來,并使用另外的字符包圍左右。
第一個參數就是列名,然后open,close,seperator
代碼如下:
public class GenericUDAFColConcat extends AbstractGenericUDAFResolver{
public GenericUDAFColConcat() {
}
/**
在getEvaluator中做一些類型檢查,
*/
@Override
public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
// col_concat這個udaf需要接收4個參數
if(parameters.length != 4){
throw new UDFArgumentTypeException(parameters.length - 1,
"COL_CONCAT requires 4 argument, got " + parameters.length);
}
// 且只能用于連接一下PRIMITIVE類型的列
if(parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE){
throw new UDFArgumentTypeException(0,
"COL_CONCAT can only be used to concat PRIMITIVE type column, got " + parameters[0].getTypeName());
}
// 分隔符和包圍符,只能時char或者STRING
for(int i = 1; i < parameters.length; ++i){
if(parameters[i].getCategory() != ObjectInspector.Category.PRIMITIVE){
throw new UDFArgumentTypeException(i,
"COL_CONCAT only receive type CHAR/STRING as its 2nd to 4th argument's type, got " + parameters[i].getTypeName());
}
PrimitiveObjectInspector poi = (PrimitiveObjectInspector) TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(parameters[i]);
if(poi.getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.CHAR &&
poi.getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING){
throw new UDFArgumentTypeException(i,
"COL_CONCAT only receive type CHAR/STRING as its 2nd to 4th argument's type, got " + parameters[i].getTypeName());
}
}
// 返回自定義的XXXEvaluator
return new GenericUDAFCOLCONCATEvaluator();
}
// 前一節也說過需要實現AbstractAggregationBuffer用來保存聚合的值
private static class ColCollectAggregationBuffer extends GenericUDAFEvaluator.AbstractAggregationBuffer{
// 遍歷的列值暫時都方放到列表中保存起來。
private List<String> colValueList ;
private String open;
private String close;
private String seperator;
private boolean isInit;
public ColCollectAggregationBuffer() {
colValueList = new LinkedList<>();
this.isInit = false;
}
public void init(String open, String close, String seperator){
this.open = open;
this.close = close;
this.seperator = seperator;
this.isInit = true;
}
public boolean isInit(){
return isInit;
}
public String concat(){
String c = StringUtils.join(colValueList,seperator);
return open + c + close;
}
}
public static class GenericUDAFCOLCONCATEvaluator extends GenericUDAFEvaluator{
// transient避免序列化,因為這些成員其實都是在init中初始化了,沒有序列化的意義
// inputOIs用來保存PARTIAL1和COMPELE輸入數據的oi,這個各個階段都可能不一樣
private transient List<ObjectInspector> inputOIs = new LinkedList<>();
private transient Mode m;
private transient String pString;
// soi保存PARTIAL2和FINAL的輸入數據的oi
private transient StructObjectInspector soi;
private transient ListObjectInspector valueFieldOI;
private transient PrimitiveObjectInspector openFieldOI;
private transient PrimitiveObjectInspector closeFieldOI;
private transient PrimitiveObjectInspector seperatorFieldOI;
private transient StructField valueField;
private transient StructField openField;
private transient StructField closeField;
private transient StructField seperatorField;
@Override
public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
// 父類的init必須調用
super.init(m,parameters);
this.m = m;
pString = "";
for(ObjectInspector p : parameters){
pString += p.getTypeName();
}
if(m == Mode.PARTIAL1 || m == Mode.COMPLETE){
// 在PARTIAL1和COMPLETE階段,輸入數據都是原始表中數據,而不是中間聚合數據,這里初始化inputOIs
inputOIs.clear();
for(ObjectInspector p : parameters){
inputOIs.add((PrimitiveObjectInspector)p);
}
}else {
// FINAL和PARTIAL2的輸入數據OI都是上一階段的輸出,而不是原始表數據,這里parameter[0]其實就是上一階段的輸出oi,具體情況看下面
soi = (StructObjectInspector)parameters[0];
valueField = soi.getStructFieldRef("values");
valueFieldOI = (ListObjectInspector)valueField.getFieldObjectInspector();
openField = soi.getStructFieldRef("open");
openFieldOI = (PrimitiveObjectInspector) openField.getFieldObjectInspector();
closeField = soi.getStructFieldRef("close");
closeFieldOI = (PrimitiveObjectInspector)closeField.getFieldObjectInspector();
seperatorField = soi.getStructFieldRef("seperator");
seperatorFieldOI = (PrimitiveObjectInspector)seperatorField.getFieldObjectInspector();
}
// 這里開始返回各個階段的輸出OI
if(m == Mode.PARTIAL1 || m == Mode.PARTIAL2){
// 后面的terminatePartial實現中,PARTIAL1 PARTIAL2的輸出數據都是一個列表,我把中間聚合和結果values, 以及open,close, seperator
// 按序方到列表中,所以這個地方返回的oi是一個StructObjectInspector的實現類,它能夠獲取list中的值。
ArrayList<ObjectInspector> foi = new ArrayList<>();
foi.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector));
foi.add(
PrimitiveObjectInspectorFactory.javaStringObjectInspector
);
foi.add(
PrimitiveObjectInspectorFactory.javaStringObjectInspector
);
foi.add(
PrimitiveObjectInspectorFactory.javaStringObjectInspector
);
ArrayList<String> fname = new ArrayList<String>();
fname.add("values");
fname.add("open");
fname.add("close");
fname.add("seperator");
return ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
}else{
// COMPLETE和FINAL都是返回最終聚合結果了,也就是String,所以這里返回javaStringObjectInspector即可
return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
}
}
@Override
public AggregationBuffer getNewAggregationBuffer() throws HiveException {
return new ColCollectAggregationBuffer();
}
@Override
public void reset(AggregationBuffer aggregationBuffer) throws HiveException {
((ColCollectAggregationBuffer)aggregationBuffer).colValueList.clear();
}
// PARTIAL1和COMPLETE調用,iterate里就是把原始數據(參數objects[0])中的值保存到aggregationBuffer的列表中
@Override
public void iterate(AggregationBuffer aggregationBuffer, Object[] objects) throws HiveException {
assert objects.length == 4;
ColCollectAggregationBuffer ccAggregationBuffer = (ColCollectAggregationBuffer)aggregationBuffer;
ccAggregationBuffer.colValueList.add(
PrimitiveObjectInspectorUtils.getString(objects[0], (PrimitiveObjectInspector)inputOIs.get(0)));
if(!ccAggregationBuffer.isInit()){
ccAggregationBuffer.init(
PrimitiveObjectInspectorUtils.getString(objects[1], (PrimitiveObjectInspector)inputOIs.get(1)),
PrimitiveObjectInspectorUtils.getString(objects[2],(PrimitiveObjectInspector)inputOIs.get(2)),
PrimitiveObjectInspectorUtils.getString(objects[3],(PrimitiveObjectInspector)inputOIs.get(3))
);
}
}
// PARTIAL1和PARTIAL2調用,沒做什么,但是返回的值的一個‘List<Object> partialRet’ 和init中返回的StructObjectInspector對應,
@Override
public Object terminatePartial(AggregationBuffer aggregationBuffer) throws HiveException {
ColCollectAggregationBuffer ccAggregationBuffer = (ColCollectAggregationBuffer)aggregationBuffer;
List<Object> partialRet = new ArrayList<>();
partialRet.add(ccAggregationBuffer.colValueList);
partialRet.add(ccAggregationBuffer.open);
partialRet.add(ccAggregationBuffer.close);
partialRet.add(ccAggregationBuffer.seperator);
return partialRet;
}
// PARTIAL2和FINAL調用,參數partial對應上面terminatePartial返回的列表,
@Override
public void merge(AggregationBuffer aggregationBuffer, Object partial) throws HiveException {
ColCollectAggregationBuffer ccAggregationBuffer = (ColCollectAggregationBuffer)aggregationBuffer;
if(partial != null){
// soi在init中初始化了,用它來獲取partial中數據。
List<Object> partialList = soi.getStructFieldsDataAsList(partial);
// terminalPartial中數據被保存在list中,這個地方拿出來只是簡單了合并兩個list,其他不變。
List<String> values = (List<String>)valueFieldOI.getList(partialList.get(0));
ccAggregationBuffer.colValueList.addAll(values);
if(!ccAggregationBuffer.isInit){
ccAggregationBuffer.open = PrimitiveObjectInspectorUtils.getString(partialList.get(1), openFieldOI);
ccAggregationBuffer.close = PrimitiveObjectInspectorUtils.getString(partialList.get(2), closeFieldOI);
ccAggregationBuffer.seperator = PrimitiveObjectInspectorUtils.getString(partialList.get(3), seperatorFieldOI);
}
}
}
// FINAL和COMPLETE調用,此時aggregationBuffer中用list保存了原始表表中一列的所有值,這里完成連接操作,返回一個string類型的連接結果。
@Override
public Object terminate(AggregationBuffer aggregationBuffer) throws HiveException {
return ((ColCollectAggregationBuffer)aggregationBuffer).concat();
}
}
}
2.3 自定義表生成函數
待完成。。。