
A Look at flink Table's AggregateFunction

Author: code4it · Published 2019-02-09 · Column: 码匠的流水账

This article mainly takes a look at flink Table's AggregateFunction.

Example

import java.util.Iterator;

import org.apache.flink.table.functions.AggregateFunction;

/**
 * Accumulator for WeightedAvg.
 */
public static class WeightedAvgAccum {
    public long sum = 0;
    public int count = 0;
}
​
/**
 * Weighted Average user-defined aggregate function.
 */
public static class WeightedAvg extends AggregateFunction<Long, WeightedAvgAccum> {
​
    @Override
    public WeightedAvgAccum createAccumulator() {
        return new WeightedAvgAccum();
    }
​
    @Override
    public Long getValue(WeightedAvgAccum acc) {
        if (acc.count == 0) {
            return 0L;
        } else {
            return acc.sum / acc.count;
        }
    }
​
    public void accumulate(WeightedAvgAccum acc, long iValue, int iWeight) {
        acc.sum += iValue * iWeight;
        acc.count += iWeight;
    }
​
    public void retract(WeightedAvgAccum acc, long iValue, int iWeight) {
        acc.sum -= iValue * iWeight;
        acc.count -= iWeight;
    }
    
    public void merge(WeightedAvgAccum acc, Iterable<WeightedAvgAccum> it) {
        Iterator<WeightedAvgAccum> iter = it.iterator();
        while (iter.hasNext()) {
            WeightedAvgAccum a = iter.next();
            acc.count += a.count;
            acc.sum += a.sum;
        }
    }
    
    public void resetAccumulator(WeightedAvgAccum acc) {
        acc.count = 0;
        acc.sum = 0L;
    }
}
​
// register function
BatchTableEnvironment tEnv = ...
tEnv.registerFunction("wAvg", new WeightedAvg());
​
// use function
tEnv.sqlQuery("SELECT user, wAvg(points, level) AS avgPoints FROM userScores GROUP BY user");
  • WeightedAvg extends AggregateFunction and implements the createAccumulator, getValue, accumulate, retract, merge and resetAccumulator methods; a Table API usage sketch follows below
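Besides SQL, the registered function can also be called from the Java Table API. A minimal sketch, assuming the userScores table from the query above is available via tEnv.scan (this snippet is illustrative and not part of the original example):

import org.apache.flink.table.api.Table;

// use the registered function in the Java Table API (sketch; assumes the userScores
// table has been registered as in the SQL example above)
Table userScores = tEnv.scan("userScores");
Table result = userScores
    .groupBy("user")
    .select("user, wAvg(points, level) as avgPoints");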

AggregateFunction

flink-table_2.11-1.7.1-sources.jar!/org/apache/flink/table/functions/AggregateFunction.scala

abstract class AggregateFunction[T, ACC] extends UserDefinedFunction {
  /**
    * Creates and init the Accumulator for this [[AggregateFunction]].
    *
    * @return the accumulator with the initial value
    */
  def createAccumulator(): ACC
​
  /**
    * Called every time when an aggregation result should be materialized.
    * The returned value could be either an early and incomplete result
    * (periodically emitted as data arrive) or the final result of the
    * aggregation.
    *
    * @param accumulator the accumulator which contains the current
    *                    aggregated results
    * @return the aggregation result
    */
  def getValue(accumulator: ACC): T
​
    /**
    * Returns true if this AggregateFunction can only be applied in an OVER window.
    *
    * @return true if the AggregateFunction requires an OVER window, false otherwise.
    */
  def requiresOver: Boolean = false
​
  /**
    * Returns the TypeInformation of the AggregateFunction's result.
    *
    * @return The TypeInformation of the AggregateFunction's result or null if the result type
    *         should be automatically inferred.
    */
  def getResultType: TypeInformation[T] = null
​
  /**
    * Returns the TypeInformation of the AggregateFunction's accumulator.
    *
    * @return The TypeInformation of the AggregateFunction's accumulator or null if the
    *         accumulator type should be automatically inferred.
    */
  def getAccumulatorType: TypeInformation[ACC] = null
}
  • AggregateFunction extends UserDefinedFunction; it has two type parameters, T for the result value and ACC for the accumulator; it declares the createAccumulator, getValue, requiresOver, getResultType and getAccumulatorType methods (of these, subclasses must implement createAccumulator and getValue; a sketch of overriding the two type methods follows this list)
  • In addition, an accumulate method is not declared on AggregateFunction itself but must be defined and implemented by subclasses; it takes the accumulator (ACC) plus the aggregate's input values and returns void. The retract, merge and resetAccumulator methods are optional and are defined and implemented by subclasses as needed
  • For datastream bounded OVER aggregations, the retract method is required; it takes the accumulator (ACC) and the input values and returns void. For datastream session window grouping aggregations and dataset grouping aggregations, the merge method is required; it takes the accumulator (ACC) and a java.lang.Iterable<ACC> of accumulators and returns void. For dataset grouping aggregations, the resetAccumulator method is required; it takes the accumulator (ACC) and returns void
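getResultType and getAccumulatorType return null by default, so the result and accumulator types are normally inferred automatically; when inference is not sufficient they can be overridden. A minimal sketch, reusing the WeightedAvgAccum class from the example above (the class name WeightedAvgWithTypes is made up for illustration):

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.table.functions.AggregateFunction;

// Sketch: a WeightedAvg variant that declares its result and accumulator types explicitly
// instead of relying on automatic type inference.
public class WeightedAvgWithTypes extends AggregateFunction<Long, WeightedAvgAccum> {

    @Override
    public WeightedAvgAccum createAccumulator() {
        return new WeightedAvgAccum();
    }

    @Override
    public Long getValue(WeightedAvgAccum acc) {
        return acc.count == 0 ? 0L : acc.sum / acc.count;
    }

    // required by contract, but resolved by the code generator rather than the base class
    public void accumulate(WeightedAvgAccum acc, long iValue, int iWeight) {
        acc.sum += iValue * iWeight;
        acc.count += iWeight;
    }

    @Override
    public TypeInformation<Long> getResultType() {
        return Types.LONG;
    }

    @Override
    public TypeInformation<WeightedAvgAccum> getAccumulatorType() {
        return TypeInformation.of(WeightedAvgAccum.class);
    }
}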

DataSetPreAggFunction

flink-table_2.11-1.7.1-sources.jar!/org/apache/flink/table/runtime/aggregate/DataSetPreAggFunction.scala

class DataSetPreAggFunction(genAggregations: GeneratedAggregationsFunction)
  extends AbstractRichFunction
  with GroupCombineFunction[Row, Row]
  with MapPartitionFunction[Row, Row]
  with Compiler[GeneratedAggregations]
  with Logging {
​
  private var output: Row = _
  private var accumulators: Row = _
​
  private var function: GeneratedAggregations = _
​
  override def open(config: Configuration) {
    LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
                s"Code:\n$genAggregations.code")
    val clazz = compile(
      getRuntimeContext.getUserCodeClassLoader,
      genAggregations.name,
      genAggregations.code)
    LOG.debug("Instantiating AggregateHelper.")
    function = clazz.newInstance()
​
    output = function.createOutputRow()
    accumulators = function.createAccumulators()
  }
​
  override def combine(values: Iterable[Row], out: Collector[Row]): Unit = {
    // reset accumulators
    function.resetAccumulator(accumulators)
​
    val iterator = values.iterator()
​
    var record: Row = null
    while (iterator.hasNext) {
      record = iterator.next()
      // accumulate
      function.accumulate(accumulators, record)
    }
​
    // set group keys and accumulators to output
    function.setAggregationResults(accumulators, output)
    function.setForwardedFields(record, output)
​
    out.collect(output)
  }
​
  override def mapPartition(values: Iterable[Row], out: Collector[Row]): Unit = {
    combine(values, out)
  }
}
  • DataSetPreAggFunction's combine method calls function.accumulate(accumulators, record), where accumulators is a Row holding the WeightedAvgAccum and record is a Row; function is a generated class that extends GeneratedAggregations, its code is carried in genAggregations, and genAggregations is produced by AggregateUtil.createDataSetAggregateFunctions; the generated accumulate in turn calls WeightedAvg's accumulate method (a hand-written sketch of that adapter follows below)
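The generated class is compiled at runtime from the source code carried in genAggregations, so it never appears in the repository. As a mental model only, here is a hand-written Java sketch of what the generated accumulate adapter conceptually does for the WeightedAvg example (the class name and the field positions are assumptions for illustration, not the actual generated code):

import org.apache.flink.types.Row;

// Hand-written sketch of the generated adapter for the WeightedAvg example. The real class
// is produced by the code generator; the field positions used here (accumulator at index 0,
// points and level at indexes 1 and 2 of the input row) are assumptions for illustration.
public class HandWrittenAggregationsSketch {

    private final WeightedAvg function0 = new WeightedAvg();

    public Row createAccumulators() {
        // one field per aggregate call, each holding the user-defined accumulator
        Row accs = new Row(1);
        accs.setField(0, function0.createAccumulator());
        return accs;
    }

    public void accumulate(Row accumulators, Row input) {
        // unpack the user-defined accumulator and the aggregate call's arguments,
        // then delegate to the user-defined accumulate with its native signature
        WeightedAvgAccum acc0 = (WeightedAvgAccum) accumulators.getField(0);
        long points = (Long) input.getField(1);
        int level = (Integer) input.getField(2);
        function0.accumulate(acc0, points, level);
    }

    public void resetAccumulator(Row accumulators) {
        // delegate to the user-defined resetAccumulator
        function0.resetAccumulator((WeightedAvgAccum) accumulators.getField(0));
    }
}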

GeneratedAggregations

flink-table_2.11-1.7.1-sources.jar!/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala

abstract class GeneratedAggregations extends Function {
​
  /**
    * Setup method for [[org.apache.flink.table.functions.AggregateFunction]].
    * It can be used for initialization work. By default, this method does nothing.
    *
    * @param ctx The runtime context.
    */
  def open(ctx: RuntimeContext)
​
  /**
    * Sets the results of the aggregations (partial or final) to the output row.
    * Final results are computed with the aggregation function.
    * Partial results are the accumulators themselves.
    *
    * @param accumulators the accumulators (saved in a row) which contains the current
    *                     aggregated results
    * @param output       output results collected in a row
    */
  def setAggregationResults(accumulators: Row, output: Row)
​
  /**
    * Copies forwarded fields, such as grouping keys, from input row to output row.
    *
    * @param input        input values bundled in a row
    * @param output       output results collected in a row
    */
  def setForwardedFields(input: Row, output: Row)
​
  /**
    * Accumulates the input values to the accumulators.
    *
    * @param accumulators the accumulators (saved in a row) which contains the current
    *                     aggregated results
    * @param input        input values bundled in a row
    */
  def accumulate(accumulators: Row, input: Row)
​
  /**
    * Retracts the input values from the accumulators.
    *
    * @param accumulators the accumulators (saved in a row) which contains the current
    *                     aggregated results
    * @param input        input values bundled in a row
    */
  def retract(accumulators: Row, input: Row)
​
  /**
    * Initializes the accumulators and save them to a accumulators row.
    *
    * @return a row of accumulators which contains the aggregated results
    */
  def createAccumulators(): Row
​
  /**
    * Creates an output row object with the correct arity.
    *
    * @return an output row object with the correct arity.
    */
  def createOutputRow(): Row
​
  /**
    * Merges two rows of accumulators into one row.
    *
    * @param a First row of accumulators
    * @param b The other row of accumulators
    * @return A row with the merged accumulators of both input rows.
    */
  def mergeAccumulatorsPair(a: Row, b: Row): Row
​
  /**
    * Resets all the accumulators.
    *
    * @param accumulators the accumulators (saved in a row) which contains the current
    *                     aggregated results
    */
  def resetAccumulator(accumulators: Row)
​
  /**
    * Cleanup for the accumulators.
    */
  def cleanup()
​
  /**
    * Tear-down method for [[org.apache.flink.table.functions.AggregateFunction]].
    * It can be used for clean up work. By default, this method does nothing.
    */
  def close()
}
  • GeneratedAggregations declares accumulate(accumulators: Row, input: Row), retract(accumulators: Row, input: Row), resetAccumulator(accumulators: Row), mergeAccumulatorsPair(a: Row, b: Row) and other Row-based methods (a conceptual sketch of the pair-wise merge bridge follows below)
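The pair-wise mergeAccumulatorsPair(a: Row, b: Row) signature has to be bridged to the user-defined merge(acc, Iterable) method. A conceptual hand-written sketch of that bridge for the WeightedAvg example (an assumption about what the generated code does, for illustration only):

import java.util.Collections;

import org.apache.flink.types.Row;

// Conceptual sketch: feed the second accumulator to the user-defined merge as a
// one-element Iterable, then return the merged accumulator row. The actual generated
// code may differ; field positions are assumptions for illustration.
public class MergeAccumulatorsPairSketch {

    private final WeightedAvg function0 = new WeightedAvg();

    public Row mergeAccumulatorsPair(Row a, Row b) {
        WeightedAvgAccum accA = (WeightedAvgAccum) a.getField(0);
        WeightedAvgAccum accB = (WeightedAvgAccum) b.getField(0);
        function0.merge(accA, Collections.singletonList(accB));
        a.setField(0, accA);
        return a;
    }
}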

AggregateUtil

flink-table_2.11-1.7.1-sources.jar!/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala

object AggregateUtil {
​
  type CalcitePair[T, R] = org.apache.calcite.util.Pair[T, R]
  type JavaList[T] = java.util.List[T]
​
  //......
​
  /**
    * Create functions to compute a [[org.apache.flink.table.plan.nodes.dataset.DataSetAggregate]].
    * If all aggregation functions support pre-aggregation, a pre-aggregation function and the
    * respective output type are generated as well.
    */
  private[flink] def createDataSetAggregateFunctions(
      generator: AggregationCodeGenerator,
      namedAggregates: Seq[CalcitePair[AggregateCall, String]],
      inputType: RelDataType,
      inputFieldTypeInfo: Seq[TypeInformation[_]],
      outputType: RelDataType,
      groupings: Array[Int],
      tableConfig: TableConfig): (
        Option[DataSetPreAggFunction],
        Option[TypeInformation[Row]],
        Either[DataSetAggFunction, DataSetFinalAggFunction]) = {
​
    val needRetract = false
    val (aggInFields, aggregates, isDistinctAggs, accTypes, _) = transformToAggregateFunctions(
      namedAggregates.map(_.getKey),
      inputType,
      needRetract,
      tableConfig)
​
    val (gkeyOutMapping, aggOutMapping) = getOutputMappings(
      namedAggregates,
      groupings,
      inputType,
      outputType
    )
​
    val aggOutFields = aggOutMapping.map(_._1)
​
    if (doAllSupportPartialMerge(aggregates)) {
​
      // compute preaggregation type
      val preAggFieldTypes = gkeyOutMapping.map(_._2)
        .map(inputType.getFieldList.get(_).getType)
        .map(FlinkTypeFactory.toTypeInfo) ++ accTypes
      val preAggRowType = new RowTypeInfo(preAggFieldTypes: _*)
​
      val genPreAggFunction = generator.generateAggregations(
        "DataSetAggregatePrepareMapHelper",
        inputFieldTypeInfo,
        aggregates,
        aggInFields,
        aggregates.indices.map(_ + groupings.length).toArray,
        isDistinctAggs,
        isStateBackedDataViews = false,
        partialResults = true,
        groupings,
        None,
        groupings.length + aggregates.length,
        needRetract,
        needMerge = false,
        needReset = true,
        None
      )
​
      // compute mapping of forwarded grouping keys
      val gkeyMapping: Array[Int] = if (gkeyOutMapping.nonEmpty) {
        val gkeyOutFields = gkeyOutMapping.map(_._1)
        val mapping = Array.fill[Int](gkeyOutFields.max + 1)(-1)
        gkeyOutFields.zipWithIndex.foreach(m => mapping(m._1) = m._2)
        mapping
      } else {
        new Array[Int](0)
      }
​
      val genFinalAggFunction = generator.generateAggregations(
        "DataSetAggregateFinalHelper",
        inputFieldTypeInfo,
        aggregates,
        aggInFields,
        aggOutFields,
        isDistinctAggs,
        isStateBackedDataViews = false,
        partialResults = false,
        gkeyMapping,
        Some(aggregates.indices.map(_ + groupings.length).toArray),
        outputType.getFieldCount,
        needRetract,
        needMerge = true,
        needReset = true,
        None
      )
​
      (
        Some(new DataSetPreAggFunction(genPreAggFunction)),
        Some(preAggRowType),
        Right(new DataSetFinalAggFunction(genFinalAggFunction))
      )
    }
    else {
      val genFunction = generator.generateAggregations(
        "DataSetAggregateHelper",
        inputFieldTypeInfo,
        aggregates,
        aggInFields,
        aggOutFields,
        isDistinctAggs,
        isStateBackedDataViews = false,
        partialResults = false,
        groupings,
        None,
        outputType.getFieldCount,
        needRetract,
        needMerge = false,
        needReset = true,
        None
      )
​
      (
        None,
        None,
        Left(new DataSetAggFunction(genFunction))
      )
    }
​
  }
​
  //......
}
  • AggregateUtil's createDataSetAggregateFunctions method mainly generates the GeneratedAggregationsFunction and then creates a DataSetPreAggFunction or a DataSetAggFunction. Code is generated dynamically because the parameters of user-defined methods such as accumulate vary from function to function, while the flink runtime calls the fixed accumulate(accumulators: Row, input: Row) declared on GeneratedAggregations; the generated code therefore acts as an adapter, converting the Rows inside accumulate(accumulators: Row, input: Row) into the parameters required by the user-defined accumulate method and then invoking it (a runnable batch sketch that exercises this code path follows below)
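createDataSetAggregateFunctions is only exercised when a DataSet (batch) aggregation is actually planned and executed. A minimal batch job sketch that drives this code path with the WeightedAvg function, assuming flink-table 1.7 on the classpath; the sample data, the column names userName/points/level and the class name are made up for illustration:

import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.api.java.BatchTableEnvironment;
import org.apache.flink.types.Row;

// Minimal batch job sketch that triggers the DataSet aggregation code path
// (plan -> AggregateUtil.createDataSetAggregateFunctions -> generated aggregations).
public class WeightedAvgBatchDemo {

    public static void main(String[] args) throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        BatchTableEnvironment tEnv = TableEnvironment.getTableEnvironment(env);

        // hypothetical input data: (userName, points, level)
        DataSet<Tuple3<String, Long, Integer>> input = env.fromElements(
                Tuple3.of("alice", 10L, 1),
                Tuple3.of("alice", 20L, 3),
                Tuple3.of("bob", 5L, 2));

        tEnv.registerDataSet("userScores", input, "userName, points, level");
        tEnv.registerFunction("wAvg", new WeightedAvg());   // WeightedAvg from the example above

        Table result = tEnv.sqlQuery(
                "SELECT userName, wAvg(points, level) AS avgPoints FROM userScores GROUP BY userName");

        // print() triggers execution of the batch job
        tEnv.toDataSet(result, Row.class).print();
    }
}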

Summary

  • AggregateFunction extends UserDefinedFunction; it has two type parameters, T for the result value and ACC for the accumulator; it declares the createAccumulator, getValue, getResultType and getAccumulatorType methods (of these, subclasses must implement createAccumulator and getValue). The accumulate method is not declared on AggregateFunction itself but must be defined and implemented by subclasses; it takes the accumulator (ACC) plus the input values and returns void. The retract, merge and resetAccumulator methods are optional and are defined and implemented as needed (retract is required for datastream bounded OVER aggregations and takes the accumulator and the input values, returning void; merge is required for datastream session window grouping aggregations and dataset grouping aggregations and takes the accumulator and a java.lang.Iterable<ACC> of accumulators, returning void; resetAccumulator is required for dataset grouping aggregations and takes the accumulator, returning void)
  • DataSetPreAggFunction's combine method calls function.accumulate(accumulators, record), where accumulators is a Row holding the WeightedAvgAccum and record is a Row; function is a generated class extending GeneratedAggregations, its code is carried in genAggregations, which is produced by AggregateUtil.createDataSetAggregateFunctions, and it ends up calling WeightedAvg's accumulate method; GeneratedAggregations declares accumulate(accumulators: Row, input: Row), resetAccumulator(accumulators: Row) and other Row-based methods
  • AggregateUtil's createDataSetAggregateFunctions method mainly generates the GeneratedAggregationsFunction and then creates a DataSetPreAggFunction or a DataSetAggFunction. Code is generated dynamically because the parameters of user-defined methods such as accumulate vary from function to function, while the flink runtime calls the fixed accumulate(accumulators: Row, input: Row) declared on GeneratedAggregations; the generated code acts as an adapter, converting the Rows into the parameters required by the user-defined accumulate method and then invoking it

doc

Original statement: this article is published on the Tencent Cloud Developer Community with the author's authorization; it may not be reproduced without permission.

For infringement concerns, please contact cloudcommunity@tencent.com for removal.

