前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【转】hive udaf函数求中位数

【转】hive udaf函数求中位数

作者头像
yiduwangkai
发布2019-09-17 15:58:46
1.4K0
发布2019-09-17 15:58:46
举报

第一次写UDAF,拿中位数来练手。

看下中位数定义:

MEDIAN 中位数:将一组数据按从小到大的顺序排列后,数据个数为奇数时取处在中间位置的那个数,为偶数时取最中间两个数据的平均数。下面把它写成 GenericUDAF 的形式。

1 2 3 4	中位数 (2+3)/2=2.5
1 2 3	中位数 2

代码如下

package org.apache.hadoop.hive.ql.udf.generic;
 
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
 
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.util.StringUtils;
 
 
/**
 * Hive UDAF that computes the median of a numeric column.
 *
 * <p>Non-NULL input values of a group are collected into a list. In {@code terminate} the list is
 * sorted and the result is the middle element (odd count) or the mean of the two middle elements
 * (even count), returned as a DOUBLE. NULL inputs and values that fail numeric conversion are
 * skipped. A group with no usable values yields NULL.
 *
 * <p>BYTE/SHORT/INT arguments use {@link GenericUDAFMedianEvaluatorInt}, BIGINT uses
 * {@link GenericUDAFMedianEvaluatorLong}, and FLOAT/DOUBLE use {@link GenericUDAFMedianEvaluatorDouble}.
 * All values of a group are buffered in memory, so heap use grows with group size.
 */
@Description(name="median",value=""
		+ "_FUNC_(x) return the median number of a number array. eg: median(x)")
public class GenericUDAFMedian extends AbstractGenericUDAFResolver {

	static final Log LOG = LogFactory.getLog(GenericUDAFMedian.class.getName());

	/**
	 * Selects the evaluator matching the argument's primitive type.
	 *
	 * @param parameters type info of the call arguments; exactly one is accepted
	 * @return an evaluator for INT-like, BIGINT or FLOAT/DOUBLE input
	 * @throws SemanticException (as UDFArgumentTypeException) for a wrong argument count,
	 *         a non-primitive argument, or a non-numeric primitive argument
	 */
	@Override
	public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
			throws SemanticException {
		if (parameters.length != 1) {
			throw new UDFArgumentTypeException(parameters.length - 1, "Only 1 parameter is accepted!");
		}

		ObjectInspector objectInspector = TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(parameters[0]);
		if (!ObjectInspectorUtils.compareSupported(objectInspector)) {
			throw new UDFArgumentTypeException(parameters.length - 1, "Cannot support comparison of map<> type or complex type containing map<>.");
		}

		// array<>/struct<> arguments pass compareSupported() but are not PrimitiveTypeInfo;
		// reject them with an argument error instead of failing with a ClassCastException below.
		if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
			throw new UDFArgumentTypeException(0,
					"Only numeric type(int long double) arguments are accepted but "
					+ parameters[0].getTypeName() + " was passed as parameter of index->1.");
		}

		switch (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()) {
		case BYTE:
		case SHORT:
		case INT:
			return new GenericUDAFMedianEvaluatorInt();
		case LONG:
			return new GenericUDAFMedianEvaluatorLong();
		case FLOAT:
		case DOUBLE:
			return new GenericUDAFMedianEvaluatorDouble();
		case STRING:
		case BOOLEAN:
		default:
			throw new UDFArgumentTypeException(0,
					"Only numeric type(int long double) arguments are accepted but "
					+ parameters[0].getTypeName() + " was passed as parameter of index->1.");
		}
	}

	/**
	 * Median evaluator for BYTE, SHORT and INT input; values are buffered as {@link IntWritable}.
	 * The partial aggregation is a struct {@code {list: array<int>}}.
	 */
	public static class GenericUDAFMedianEvaluatorInt extends GenericUDAFEvaluator {

		private DoubleWritable result = new DoubleWritable();
		// Raw-row input inspector (PARTIAL1 / COMPLETE).
		PrimitiveObjectInspector inputOI;
		// Partial-aggregation inspectors (PARTIAL2 / FINAL).
		StructObjectInspector structOI;
		StandardListObjectInspector listOI;
		StructField listField;
		Object[] partialResult;
		ListObjectInspector listFieldOI;
		// Inspector for the partial list's elements; elements are read through it rather than cast.
		PrimitiveObjectInspector listElementOI;

		/**
		 * Sets up input inspectors for the given mode and returns the output inspector:
		 * the partial struct for PARTIAL1/PARTIAL2, a writable double for FINAL/COMPLETE.
		 */
		@Override
		public ObjectInspector init(Mode m, ObjectInspector[] parameters)
				throws HiveException {
			assert (parameters.length == 1);
			super.init(m, parameters);

			listOI = ObjectInspectorFactory.getStandardListObjectInspector(
					PrimitiveObjectInspectorFactory.writableIntObjectInspector);
			// init input
			if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
				inputOI = (PrimitiveObjectInspector) parameters[0];
			} else {
				structOI = (StructObjectInspector) parameters[0];
				listField = structOI.getStructFieldRef("list");
				listFieldOI = (ListObjectInspector) listField.getFieldObjectInspector();
				listElementOI = (PrimitiveObjectInspector) listFieldOI.getListElementObjectInspector();
			}

			// init output
			if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {
				ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
				foi.add(listOI);
				ArrayList<String> fname = new ArrayList<String>();
				fname.add("list");
				partialResult = new Object[1];
				partialResult[0] = new ArrayList<IntWritable>();
				return ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
			} else {
				return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
			}
		}

		/** Per-group buffer holding every accepted input value. */
		static class MedianNumberAgg implements AggregationBuffer {
			List<IntWritable> aggIntegerList;
		}

		@Override
		public AggregationBuffer getNewAggregationBuffer() throws HiveException {
			MedianNumberAgg resultAgg = new MedianNumberAgg();
			reset(resultAgg);
			return resultAgg;
		}

		/** Replaces the buffer's list with a fresh empty one. */
		@Override
		public void reset(AggregationBuffer agg) throws HiveException {
			((MedianNumberAgg) agg).aggIntegerList = new ArrayList<IntWritable>();
		}

		// Ensures the conversion-failure warning is logged only once per evaluator.
		boolean warned = false;

		/** Adds one input value; NULLs and unconvertible values are skipped (not counted as 0). */
		@Override
		public void iterate(AggregationBuffer agg, Object[] parameters)
				throws HiveException {
			assert (parameters.length == 1);
			if (parameters[0] == null) {
				return;
			}
			int val;
			try {
				val = PrimitiveObjectInspectorUtils.getInt(parameters[0], inputOI);
			} catch (NullPointerException e) {
				LOG.warn("got a null value, skip it");
				return;
			} catch (NumberFormatException e) {
				if (!warned) {
					warned = true;
					LOG.warn(getClass().getSimpleName() + " " + StringUtils.stringifyException(e));
					LOG.warn("ignore similar exceptions.");
				}
				return;
			}
			((MedianNumberAgg) agg).aggIntegerList.add(new IntWritable(val));
		}

		/**
		 * Sorts the buffered values and returns their median as a double,
		 * or {@code null} if the group had no usable values.
		 */
		@SuppressWarnings("unchecked")
		@Override
		public Object terminate(AggregationBuffer agg) throws HiveException {
			List<IntWritable> values = ((MedianNumberAgg) agg).aggIntegerList;
			int size = values.size();
			if (size == 0) {
				// Empty or all-NULL group: return NULL like other SQL aggregates instead of get(-1).
				return null;
			}
			Collections.sort(values);
			int midIndex = size / 2;
			double rs;
			if (size % 2 == 1) {
				rs = values.get(midIndex).get();
			} else {
				// Widen to double before adding so two large ints cannot overflow.
				rs = ((double) values.get(midIndex - 1).get() + values.get(midIndex).get()) / 2.0;
			}
			result.set(rs);
			return result;
		}

		/** Emits a copy of the buffered values as the partial struct {@code {list}}. */
		@Override
		public Object terminatePartial(AggregationBuffer agg)
				throws HiveException {
			partialResult[0] = new ArrayList<IntWritable>(((MedianNumberAgg) agg).aggIntegerList);
			return partialResult;
		}

		/** Appends the values of a partial struct to the buffer. */
		@Override
		public void merge(AggregationBuffer agg, Object partial)
				throws HiveException {
			if (partial == null) {
				return;
			}
			MedianNumberAgg medianNumberAgg = (MedianNumberAgg) agg;
			Object partialObject = structOI.getStructFieldData(partial, listField);
			List<?> partialList = listFieldOI.getList(partialObject);
			if (partialList == null) {
				return;
			}
			for (Object element : partialList) {
				if (element != null) {
					// Read through the element inspector and copy into a fresh writable, since the
					// element is not guaranteed to be an IntWritable we may keep a reference to.
					medianNumberAgg.aggIntegerList.add(
							new IntWritable(PrimitiveObjectInspectorUtils.getInt(element, listElementOI)));
				}
			}
		}

	}

	/**
	 * Median evaluator for FLOAT and DOUBLE input; values are buffered as {@link DoubleWritable}.
	 * The partial aggregation is a struct {@code {list: array<double>}}.
	 */
	public static class GenericUDAFMedianEvaluatorDouble extends GenericUDAFEvaluator {

		private DoubleWritable result = new DoubleWritable();
		// Raw-row input inspector (PARTIAL1 / COMPLETE).
		PrimitiveObjectInspector inputOI;
		// Partial-aggregation inspectors (PARTIAL2 / FINAL).
		StructObjectInspector structOI;
		StandardListObjectInspector listOI;
		StructField listField;
		Object[] partialResult;
		ListObjectInspector listFieldOI;
		// Inspector for the partial list's elements; elements are read through it rather than cast.
		PrimitiveObjectInspector listElementOI;

		/**
		 * Sets up input inspectors for the given mode and returns the output inspector:
		 * the partial struct for PARTIAL1/PARTIAL2, a writable double for FINAL/COMPLETE.
		 */
		@Override
		public ObjectInspector init(Mode m, ObjectInspector[] parameters)
				throws HiveException {
			assert (parameters.length == 1);
			super.init(m, parameters);

			listOI = ObjectInspectorFactory.getStandardListObjectInspector(
					PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
			// init input
			if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
				inputOI = (PrimitiveObjectInspector) parameters[0];
			} else {
				structOI = (StructObjectInspector) parameters[0];
				listField = structOI.getStructFieldRef("list");
				listFieldOI = (ListObjectInspector) listField.getFieldObjectInspector();
				listElementOI = (PrimitiveObjectInspector) listFieldOI.getListElementObjectInspector();
			}

			// init output
			if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {
				ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
				foi.add(listOI);
				ArrayList<String> fname = new ArrayList<String>();
				fname.add("list");
				partialResult = new Object[1];
				partialResult[0] = new ArrayList<DoubleWritable>();
				return ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
			} else {
				return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
			}
		}

		/** Per-group buffer holding every accepted input value. */
		static class MedianNumberAgg implements AggregationBuffer {
			List<DoubleWritable> aggIntegerList;
		}

		@Override
		public AggregationBuffer getNewAggregationBuffer() throws HiveException {
			MedianNumberAgg resultAgg = new MedianNumberAgg();
			reset(resultAgg);
			return resultAgg;
		}

		/** Replaces the buffer's list with a fresh empty one. */
		@Override
		public void reset(AggregationBuffer agg) throws HiveException {
			((MedianNumberAgg) agg).aggIntegerList = new ArrayList<DoubleWritable>();
		}

		// Ensures the conversion-failure warning is logged only once per evaluator.
		boolean warned = false;

		/** Adds one input value; NULLs and unconvertible values are skipped (not counted as 0). */
		@Override
		public void iterate(AggregationBuffer agg, Object[] parameters)
				throws HiveException {
			assert (parameters.length == 1);
			if (parameters[0] == null) {
				return;
			}
			double val;
			try {
				val = PrimitiveObjectInspectorUtils.getDouble(parameters[0], inputOI);
			} catch (NullPointerException e) {
				LOG.warn("got a null value, skip it");
				return;
			} catch (NumberFormatException e) {
				if (!warned) {
					warned = true;
					LOG.warn(getClass().getSimpleName() + " " + StringUtils.stringifyException(e));
					LOG.warn("ignore similar exceptions.");
				}
				return;
			}
			((MedianNumberAgg) agg).aggIntegerList.add(new DoubleWritable(val));
		}

		/**
		 * Sorts the buffered values and returns their median,
		 * or {@code null} if the group had no usable values.
		 */
		@SuppressWarnings("unchecked")
		@Override
		public Object terminate(AggregationBuffer agg) throws HiveException {
			List<DoubleWritable> values = ((MedianNumberAgg) agg).aggIntegerList;
			int size = values.size();
			if (size == 0) {
				// Empty or all-NULL group: return NULL like other SQL aggregates instead of get(-1).
				return null;
			}
			Collections.sort(values);
			int midIndex = size / 2;
			double rs;
			if (size % 2 == 1) {
				rs = values.get(midIndex).get();
			} else {
				// Halve each operand first so two huge values cannot overflow to infinity;
				// halving is exact, so ordinary results are unchanged.
				rs = values.get(midIndex - 1).get() / 2.0 + values.get(midIndex).get() / 2.0;
			}
			result.set(rs);
			return result;
		}

		/** Emits a copy of the buffered values as the partial struct {@code {list}}. */
		@Override
		public Object terminatePartial(AggregationBuffer agg)
				throws HiveException {
			partialResult[0] = new ArrayList<DoubleWritable>(((MedianNumberAgg) agg).aggIntegerList);
			return partialResult;
		}

		/** Appends the values of a partial struct to the buffer. */
		@Override
		public void merge(AggregationBuffer agg, Object partial)
				throws HiveException {
			if (partial == null) {
				return;
			}
			MedianNumberAgg medianNumberAgg = (MedianNumberAgg) agg;
			Object partialObject = structOI.getStructFieldData(partial, listField);
			List<?> partialList = listFieldOI.getList(partialObject);
			if (partialList == null) {
				return;
			}
			for (Object element : partialList) {
				if (element != null) {
					// Read through the element inspector and copy into a fresh writable, since the
					// element is not guaranteed to be a DoubleWritable we may keep a reference to.
					medianNumberAgg.aggIntegerList.add(
							new DoubleWritable(PrimitiveObjectInspectorUtils.getDouble(element, listElementOI)));
				}
			}
		}

	}

	/**
	 * Median evaluator for BIGINT input; values are buffered as {@link LongWritable}.
	 * The partial aggregation is a struct {@code {list: array<bigint>}}.
	 */
	public static class GenericUDAFMedianEvaluatorLong extends GenericUDAFEvaluator {

		private DoubleWritable result = new DoubleWritable();
		// Raw-row input inspector (PARTIAL1 / COMPLETE).
		PrimitiveObjectInspector inputOI;
		// Partial-aggregation inspectors (PARTIAL2 / FINAL).
		StructObjectInspector structOI;
		StandardListObjectInspector listOI;
		StructField listField;
		Object[] partialResult;
		ListObjectInspector listFieldOI;
		// Inspector for the partial list's elements; elements are read through it rather than cast.
		PrimitiveObjectInspector listElementOI;

		/**
		 * Sets up input inspectors for the given mode and returns the output inspector:
		 * the partial struct for PARTIAL1/PARTIAL2, a writable double for FINAL/COMPLETE.
		 */
		@Override
		public ObjectInspector init(Mode m, ObjectInspector[] parameters)
				throws HiveException {
			assert (parameters.length == 1);
			super.init(m, parameters);

			listOI = ObjectInspectorFactory.getStandardListObjectInspector(
					PrimitiveObjectInspectorFactory.writableLongObjectInspector);
			// init input
			if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
				inputOI = (PrimitiveObjectInspector) parameters[0];
			} else {
				structOI = (StructObjectInspector) parameters[0];
				listField = structOI.getStructFieldRef("list");
				listFieldOI = (ListObjectInspector) listField.getFieldObjectInspector();
				listElementOI = (PrimitiveObjectInspector) listFieldOI.getListElementObjectInspector();
			}

			// init output
			if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {
				ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
				foi.add(listOI);
				ArrayList<String> fname = new ArrayList<String>();
				fname.add("list");
				partialResult = new Object[1];
				partialResult[0] = new ArrayList<LongWritable>();
				return ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
			} else {
				return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
			}
		}

		/** Per-group buffer holding every accepted input value. */
		static class MedianNumberAgg implements AggregationBuffer {
			List<LongWritable> aggIntegerList;
		}

		@Override
		public AggregationBuffer getNewAggregationBuffer() throws HiveException {
			MedianNumberAgg resultAgg = new MedianNumberAgg();
			reset(resultAgg);
			return resultAgg;
		}

		/** Replaces the buffer's list with a fresh empty one. */
		@Override
		public void reset(AggregationBuffer agg) throws HiveException {
			((MedianNumberAgg) agg).aggIntegerList = new ArrayList<LongWritable>();
		}

		// Ensures the conversion-failure warning is logged only once per evaluator.
		boolean warned = false;

		/** Adds one input value; NULLs and unconvertible values are skipped (not counted as 0). */
		@Override
		public void iterate(AggregationBuffer agg, Object[] parameters)
				throws HiveException {
			assert (parameters.length == 1);
			if (parameters[0] == null) {
				return;
			}
			long val;
			try {
				val = PrimitiveObjectInspectorUtils.getLong(parameters[0], inputOI);
			} catch (NullPointerException e) {
				LOG.warn("got a null value, skip it");
				return;
			} catch (NumberFormatException e) {
				if (!warned) {
					warned = true;
					LOG.warn(getClass().getSimpleName() + " " + StringUtils.stringifyException(e));
					LOG.warn("ignore similar exceptions.");
				}
				return;
			}
			((MedianNumberAgg) agg).aggIntegerList.add(new LongWritable(val));
		}

		/**
		 * Sorts the buffered values and returns their median as a double,
		 * or {@code null} if the group had no usable values.
		 */
		@SuppressWarnings("unchecked")
		@Override
		public Object terminate(AggregationBuffer agg) throws HiveException {
			List<LongWritable> values = ((MedianNumberAgg) agg).aggIntegerList;
			int size = values.size();
			if (size == 0) {
				// Empty or all-NULL group: return NULL like other SQL aggregates instead of get(-1).
				return null;
			}
			Collections.sort(values);
			int midIndex = size / 2;
			double rs;
			if (size % 2 == 1) {
				rs = values.get(midIndex).get();
			} else {
				// Widen to double before adding so two large longs cannot overflow.
				rs = ((double) values.get(midIndex - 1).get() + values.get(midIndex).get()) / 2.0;
			}
			result.set(rs);
			return result;
		}

		/** Emits a copy of the buffered values as the partial struct {@code {list}}. */
		@Override
		public Object terminatePartial(AggregationBuffer agg)
				throws HiveException {
			partialResult[0] = new ArrayList<LongWritable>(((MedianNumberAgg) agg).aggIntegerList);
			return partialResult;
		}

		/** Appends the values of a partial struct to the buffer. */
		@Override
		public void merge(AggregationBuffer agg, Object partial)
				throws HiveException {
			if (partial == null) {
				return;
			}
			MedianNumberAgg medianNumberAgg = (MedianNumberAgg) agg;
			Object partialObject = structOI.getStructFieldData(partial, listField);
			List<?> partialList = listFieldOI.getList(partialObject);
			if (partialList == null) {
				return;
			}
			for (Object element : partialList) {
				if (element != null) {
					// Read through the element inspector and copy into a fresh writable, since the
					// element is not guaranteed to be a LongWritable we may keep a reference to.
					medianNumberAgg.aggIntegerList.add(
							new LongWritable(PrimitiveObjectInspectorUtils.getLong(element, listElementOI)));
				}
			}
		}

	}

}

测试:

-- Test script for the median UDAF.
-- NOTE(review): assumes a one-row helper table named `dual` exists in datawarehouse
-- (Hive has no built-in DUAL table) - confirm before running.
use datawarehouse;
add jar /home/hadoop/shengli/median.jar;
create temporary function median as 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMedian';

-- INT, odd count {7, 8, 1}: expected 7.0
select median(id) from
(
select 7 id from dual
union all
select 8 id from dual
union all
select 1 id from dual
) a;

-- BIGINT, even count {1, 2}: expected 1.5
select median(id) from
(
select cast(1 as bigint) id from dual
union all
select cast(2 as bigint) id from dual
) a;

-- DOUBLE, even count {1.0, 2.3}: expected 1.65
select median(id) from
(
select 1.0 id from dual
union all
select 2.3 id from dual
) a;

-- INT, odd count {1, 2, 3}: expected 2.0
select median(id) from
(
select 1 id from dual
union all
select 2 id from dual
union all
select 3 id from dual
) a;

-- All-NULL input: expected NULL.
-- A bare NULL literal has type void, which the UDAF rejects, so cast it to a numeric type.
select median(id) from
(
select cast(null as int) id from dual
) a;

-- ---------------------------------
-- Grouped: a -> {-2, 1, 3, 4} expected 2.0; b -> {4, 5, 6} expected 5.0
select type,median(id) from
(
select 'a' type,3 id from dual
union all
select 'a' type,-2 id from dual
union all
select 'a' type,1 id from dual
union all
select 'a' type,4 id from dual
union all
select 'b' type,6 id from dual
union all
select 'b' type,5 id from dual
union all
select 'b' type,4 id from dual
) a
group by type;
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
腾讯云服务器利旧
云服务器(Cloud Virtual Machine,CVM)提供安全可靠的弹性计算服务。 您可以实时扩展或缩减计算资源,适应变化的业务需求,并只需按实际使用的资源计费。使用 CVM 可以极大降低您的软硬件采购成本,简化 IT 运维工作。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档